Browse Source

MINOR: Add a duplicate() method to Message classes (#8556)

Reviewers: Chia-Ping Tsai <chia7712@gmail.com>, Jason Gustafson <jason@confluent.io>
pull/8659/head
Colin Patrick McCabe 5 years ago committed by GitHub
parent
commit
14e426c9a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      clients/src/main/java/org/apache/kafka/common/protocol/Message.java
  2. 9
      clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java
  3. 16
      clients/src/test/java/org/apache/kafka/common/message/MessageTest.java
  4. 6
      generator/src/main/java/org/apache/kafka/message/FieldType.java
  5. 114
      generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java

7
clients/src/main/java/org/apache/kafka/common/protocol/Message.java

@ -148,4 +148,11 @@ public interface Message { @@ -148,4 +148,11 @@ public interface Message {
* @return The raw tagged fields.
*/
List<RawTaggedField> unknownTaggedFields();
/**
* Make a deep copy of the message.
*
* @return A copy of the message which does not share any mutable fields.
*/
Message duplicate();
}

9
clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java

@ -154,4 +154,13 @@ public final class MessageUtil { @@ -154,4 +154,13 @@ public final class MessageUtil {
}
return node.asDouble();
}
public static byte[] duplicate(byte[] array) {
if (array == null) {
return null;
}
byte[] newArray = new byte[array.length];
System.arraycopy(array, 0, newArray, 0, array.length);
return newArray;
}
}

16
clients/src/test/java/org/apache/kafka/common/message/MessageTest.java

@ -61,6 +61,7 @@ import java.util.function.Supplier; @@ -61,6 +61,7 @@ import java.util.function.Supplier;
import static java.util.Collections.singletonList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -620,14 +621,29 @@ public final class MessageTest { @@ -620,14 +621,29 @@ public final class MessageTest {
message.setMyNullableString("notNull");
message.setMyInt16((short) 3);
message.setMyString("test string");
SimpleExampleMessageData duplicate = message.duplicate();
assertEquals(duplicate, message);
assertEquals(message, duplicate);
duplicate.setMyTaggedIntArray(Collections.singletonList(123));
assertFalse(duplicate.equals(message));
assertFalse(message.equals(duplicate));
testAllMessageRoundTripsFromVersion((short) 2, message);
}
private void testAllMessageRoundTrips(Message message) throws Exception {
testDuplication(message);
testAllMessageRoundTripsFromVersion(message.lowestSupportedVersion(), message);
}
private void testDuplication(Message message) {
Message duplicate = message.duplicate();
assertEquals(duplicate, message);
assertEquals(message, duplicate);
assertEquals(duplicate.hashCode(), message.hashCode());
assertEquals(message.hashCode(), duplicate.hashCode());
}
private void testAllMessageRoundTripsBeforeVersion(short beforeVersion, Message message, Message expected) throws Exception {
testAllMessageRoundTripsBetweenVersions((short) 0, beforeVersion, message, expected);
}

6
generator/src/main/java/org/apache/kafka/message/FieldType.java

@ -20,7 +20,7 @@ package org.apache.kafka.message; @@ -20,7 +20,7 @@ package org.apache.kafka.message;
import java.util.Optional;
public interface FieldType {
String STRUCT_PREFIX = "[]";
String ARRAY_PREFIX = "[]";
final class BoolFieldType implements FieldType {
static final BoolFieldType INSTANCE = new BoolFieldType();
@ -268,8 +268,8 @@ public interface FieldType { @@ -268,8 +268,8 @@ public interface FieldType {
case BytesFieldType.NAME:
return BytesFieldType.INSTANCE;
default:
if (string.startsWith(STRUCT_PREFIX)) {
String elementTypeString = string.substring(STRUCT_PREFIX.length());
if (string.startsWith(ARRAY_PREFIX)) {
String elementTypeString = string.substring(ARRAY_PREFIX.length());
if (elementTypeString.length() == 0) {
throw new RuntimeException("Can't parse array type " + string +
". No element type found.");

114
generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java

@ -112,6 +112,8 @@ public final class MessageDataGenerator { @@ -112,6 +112,8 @@ public final class MessageDataGenerator {
buffer.printf("%n");
generateClassHashCode(struct, isSetElement);
buffer.printf("%n");
generateClassDuplicate(className, struct);
buffer.printf("%n");
generateClassToString(className, struct);
generateFieldAccessors(struct, isSetElement);
buffer.printf("%n");
@ -193,10 +195,16 @@ public final class MessageDataGenerator { @@ -193,10 +195,16 @@ public final class MessageDataGenerator {
collectionType(className), className);
buffer.incrementIndent();
generateHashSetZeroArgConstructor(className);
buffer.printf("%n");
generateHashSetSizeArgConstructor(className);
buffer.printf("%n");
generateHashSetIteratorConstructor(className);
buffer.printf("%n");
generateHashSetFindMethod(className, struct);
buffer.printf("%n");
generateHashSetFindAllMethod(className, struct);
buffer.printf("%n");
generateCollectionDuplicateMethod(className, struct);
buffer.decrementIndent();
buffer.printf("}%n");
}
@ -207,7 +215,6 @@ public final class MessageDataGenerator { @@ -207,7 +215,6 @@ public final class MessageDataGenerator {
buffer.printf("super();%n");
buffer.decrementIndent();
buffer.printf("}%n");
buffer.printf("%n");
}
private void generateHashSetSizeArgConstructor(String className) {
@ -216,7 +223,6 @@ public final class MessageDataGenerator { @@ -216,7 +223,6 @@ public final class MessageDataGenerator {
buffer.printf("super(expectedNumElements);%n");
buffer.decrementIndent();
buffer.printf("}%n");
buffer.printf("%n");
}
private void generateHashSetIteratorConstructor(String className) {
@ -226,7 +232,6 @@ public final class MessageDataGenerator { @@ -226,7 +232,6 @@ public final class MessageDataGenerator {
buffer.printf("super(iterator);%n");
buffer.decrementIndent();
buffer.printf("}%n");
buffer.printf("%n");
}
private void generateHashSetFindMethod(String className, StructSpec struct) {
@ -239,7 +244,6 @@ public final class MessageDataGenerator { @@ -239,7 +244,6 @@ public final class MessageDataGenerator {
buffer.printf("return find(_key);%n");
buffer.decrementIndent();
buffer.printf("}%n");
buffer.printf("%n");
}
private void generateHashSetFindAllMethod(String className, StructSpec struct) {
@ -252,7 +256,6 @@ public final class MessageDataGenerator { @@ -252,7 +256,6 @@ public final class MessageDataGenerator {
buffer.printf("return findAll(_key);%n");
buffer.decrementIndent();
buffer.printf("}%n");
buffer.printf("%n");
}
private void generateKeyElement(String className, StructSpec struct) {
@ -273,6 +276,22 @@ public final class MessageDataGenerator { @@ -273,6 +276,22 @@ public final class MessageDataGenerator {
collect(Collectors.joining(", "));
}
private void generateCollectionDuplicateMethod(String className, StructSpec struct) {
headerGenerator.addImport(MessageGenerator.LIST_CLASS);
buffer.printf("public %s duplicate() {%n", collectionType(className));
buffer.incrementIndent();
buffer.printf("%s _duplicate = new %s(size());%n",
collectionType(className), collectionType(className));
buffer.printf("for (%s _element : this) {%n", className);
buffer.incrementIndent();
buffer.printf("_duplicate.add(_element.duplicate());%n");
buffer.decrementIndent();
buffer.printf("}%n");
buffer.printf("return _duplicate;%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
private void generateFieldDeclarations(StructSpec struct, boolean isSetElement) {
for (FieldSpec field : struct.fields()) {
generateFieldDeclaration(field);
@ -2112,6 +2131,83 @@ public final class MessageDataGenerator { @@ -2112,6 +2131,83 @@ public final class MessageDataGenerator {
}
}
private void generateClassDuplicate(String className, StructSpec struct) {
buffer.printf("@Override%n");
buffer.printf("public %s duplicate() {%n", className);
buffer.incrementIndent();
buffer.printf("%s _duplicate = new %s();%n", className, className);
for (FieldSpec field : struct.fields()) {
generateFieldDuplicate(new Target(field,
field.camelCaseName(),
field.camelCaseName(),
input -> String.format("_duplicate.%s = %s", field.camelCaseName(), input)));
}
buffer.printf("return _duplicate;%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
private void generateFieldDuplicate(Target target) {
FieldSpec field = target.field();
if ((field.type() instanceof FieldType.BoolFieldType) ||
(field.type() instanceof FieldType.Int8FieldType) ||
(field.type() instanceof FieldType.Int16FieldType) ||
(field.type() instanceof FieldType.Int32FieldType) ||
(field.type() instanceof FieldType.Int64FieldType) ||
(field.type() instanceof FieldType.Float64FieldType) ||
(field.type() instanceof FieldType.UUIDFieldType)) {
buffer.printf("%s;%n", target.assignmentStatement(target.sourceVariable()));
} else {
IsNullConditional cond = IsNullConditional.forName(target.sourceVariable()).
nullableVersions(target.field().nullableVersions()).
ifNull(() -> buffer.printf("%s;%n", target.assignmentStatement("null")));
if (field.type().isBytes()) {
if (field.zeroCopy()) {
cond.ifShouldNotBeNull(() ->
buffer.printf("%s;%n", target.assignmentStatement(
String.format("%s.duplicate()", target.sourceVariable()))));
} else {
cond.ifShouldNotBeNull(() -> {
headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS);
buffer.printf("%s;%n", target.assignmentStatement(
String.format("MessageUtil.duplicate(%s)",
target.sourceVariable())));
});
}
} else if (field.type().isStruct()) {
cond.ifShouldNotBeNull(() ->
buffer.printf("%s;%n", target.assignmentStatement(
String.format("%s.duplicate()", target.sourceVariable()))));
} else if (field.type().isString()) {
// Strings are immutable, so we don't need to duplicate them.
cond.ifShouldNotBeNull(() ->
buffer.printf("%s;%n", target.assignmentStatement(
target.sourceVariable())));
} else if (field.type().isArray()) {
cond.ifShouldNotBeNull(() -> {
String newArrayName =
String.format("new%s", field.capitalizedCamelCaseName());
buffer.printf("%s %s = new %s(%s.size());%n",
fieldConcreteJavaType(field), newArrayName,
fieldConcreteJavaType(field), target.sourceVariable());
FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type();
buffer.printf("for (%s _element : %s) {%n",
getBoxedJavaType(arrayType.elementType()), target.sourceVariable());
buffer.incrementIndent();
generateFieldDuplicate(target.arrayElementTarget(input ->
String.format("%s.add(%s)", newArrayName, input)));
buffer.decrementIndent();
buffer.printf("}%n");
buffer.printf("%s;%n", target.assignmentStatement(
String.format("new%s", field.capitalizedCamelCaseName())));
});
} else {
throw new RuntimeException("Unhandled field type " + field.type());
}
cond.generate(buffer);
}
}
private void generateClassToString(String className, StructSpec struct) {
buffer.printf("@Override%n");
buffer.printf("public String toString() {%n");
@ -2310,13 +2406,7 @@ public final class MessageDataGenerator { @@ -2310,13 +2406,7 @@ public final class MessageDataGenerator {
field.name() + ". The only valid default for an array field " +
"is the empty array or null.");
}
FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type();
if (structRegistry.isStructArrayWithKeys(field)) {
return "new " + collectionType(arrayType.elementType().toString()) + "(0)";
} else {
headerGenerator.addImport(MessageGenerator.ARRAYLIST_CLASS);
return "new ArrayList<" + getBoxedJavaType(arrayType.elementType()) + ">()";
}
return String.format("new %s(0)", fieldConcreteJavaType(field));
} else {
throw new RuntimeException("Unsupported field type " + field.type());
}

Loading…
Cancel
Save