From cd15321e0d250253abb990af53e1f5624cf46b42 Mon Sep 17 00:00:00 2001 From: Grant Henke Date: Tue, 9 Feb 2016 10:16:57 -0800 Subject: [PATCH] KAFKA-3189: Kafka server returns UnknownServerException for inherited exceptions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … exceptions Author: Grant Henke Reviewers: Jiangjie Qin , Ismael Juma , Ewen Cheslack-Postava Closes #856 from granthenke/inherited-errors --- .../apache/kafka/common/protocol/Errors.java | 13 ++++++-- .../kafka/common/protocol/ErrorsTest.java | 17 ++++++++++ .../unit/kafka/message/MessageTest.scala | 32 ++++++++++++------- 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java index 8581544a434..4a2086954e1 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java @@ -192,10 +192,17 @@ public enum Errors { } /** - * Return the error instance associated with this exception (or UNKNOWN if there is none) + * Return the error instance associated with this exception or any of its superclasses (or UNKNOWN if there is none). + * If there are multiple matches in the class hierarchy, the first match starting from the bottom is used. */ public static Errors forException(Throwable t) { - Errors error = classToError.get(t.getClass()); - return error == null ? UNKNOWN : error; + Class clazz = t.getClass(); + while (clazz != null) { + Errors error = classToError.get(clazz); + if (error != null) + return error; + clazz = clazz.getSuperclass(); + } + return UNKNOWN; } } diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/ErrorsTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/ErrorsTest.java index b511b4bd051..2d96e587a60 100644 --- a/clients/src/test/java/org/apache/kafka/common/protocol/ErrorsTest.java +++ b/clients/src/test/java/org/apache/kafka/common/protocol/ErrorsTest.java @@ -24,6 +24,7 @@ import java.util.HashSet; import java.util.Set; import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.TimeoutException; import org.junit.Test; public class ErrorsTest { @@ -60,4 +61,20 @@ public class ErrorsTest { assertNull("The NONE error should not have an exception", Errors.NONE.exception()); } + @Test + public void testForExceptionInheritance() { + class ExtendedTimeoutException extends TimeoutException { } + + Errors expectedError = Errors.forException(new TimeoutException()); + Errors actualError = Errors.forException(new ExtendedTimeoutException()); + + assertEquals("forException should match super classes", expectedError, actualError); + } + + @Test + public void testForExceptionDefault() { + Errors error = Errors.forException(new ApiException()); + assertEquals("forException should default to unknown", Errors.UNKNOWN, error); + } + } diff --git a/core/src/test/scala/unit/kafka/message/MessageTest.scala b/core/src/test/scala/unit/kafka/message/MessageTest.scala index 3c12d13f5e5..1755633bc3c 100755 --- a/core/src/test/scala/unit/kafka/message/MessageTest.scala +++ b/core/src/test/scala/unit/kafka/message/MessageTest.scala @@ -5,7 +5,7 @@ * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software @@ -19,6 +19,8 @@ package kafka.message import java.nio._ import java.util.HashMap +import org.apache.kafka.common.protocol.Errors + import scala.collection._ import org.junit.Assert._ import org.scalatest.junit.JUnitSuite @@ -27,15 +29,15 @@ import kafka.utils.TestUtils import kafka.utils.CoreUtils import org.apache.kafka.common.utils.Utils -case class MessageTestVal(val key: Array[Byte], - val payload: Array[Byte], - val codec: CompressionCodec, +case class MessageTestVal(val key: Array[Byte], + val payload: Array[Byte], + val codec: CompressionCodec, val message: Message) class MessageTest extends JUnitSuite { - + var messages = new mutable.ArrayBuffer[MessageTestVal]() - + @Before def setUp(): Unit = { val keys = Array(null, "key".getBytes, "".getBytes) @@ -44,7 +46,7 @@ class MessageTest extends JUnitSuite { for(k <- keys; v <- vals; codec <- codecs) messages += new MessageTestVal(k, v, codec, new Message(v, k, codec)) } - + @Test def testFieldValues { for(v <- messages) { @@ -73,7 +75,7 @@ class MessageTest extends JUnitSuite { assertFalse("Message with invalid checksum should be invalid", v.message.isValid) } } - + @Test def testEquality() { for(v <- messages) { @@ -84,7 +86,7 @@ class MessageTest extends JUnitSuite { assertTrue("Should equal another message with the same content.", v.message.equals(copy)) } } - + @Test def testIsHashable() { // this is silly, but why not @@ -94,6 +96,14 @@ class MessageTest extends JUnitSuite { for(v <- messages) assertEquals(v.message, m.get(v.message)) } - + + @Test + def testExceptionMapping() { + val expected = Errors.CORRUPT_MESSAGE + val actual = Errors.forException(new InvalidMessageException()) + + assertEquals("InvalidMessageException should map to a corrupt message error", expected, actual) + } + } - +