From 14e426c9a820f992cc34f67c1f6e885c6e81ca53 Mon Sep 17 00:00:00 2001 From: Colin Patrick McCabe Date: Wed, 13 May 2020 15:26:50 -0700 Subject: [PATCH] MINOR: Add a duplicate() method to Message classes (#8556) Reviewers: Chia-Ping Tsai , Jason Gustafson --- .../apache/kafka/common/protocol/Message.java | 7 ++ .../kafka/common/protocol/MessageUtil.java | 9 ++ .../kafka/common/message/MessageTest.java | 16 +++ .../org/apache/kafka/message/FieldType.java | 6 +- .../kafka/message/MessageDataGenerator.java | 114 ++++++++++++++++-- 5 files changed, 137 insertions(+), 15 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Message.java b/clients/src/main/java/org/apache/kafka/common/protocol/Message.java index 149fe11cfe2..3ff33044fec 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Message.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Message.java @@ -148,4 +148,11 @@ public interface Message { * @return The raw tagged fields. */ List unknownTaggedFields(); + + /** + * Make a deep copy of the message. + * + * @return A copy of the message which does not share any mutable fields. + */ + Message duplicate(); } diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java b/clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java index 1f09c53c471..21eb35e4f4a 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java @@ -154,4 +154,13 @@ public final class MessageUtil { } return node.asDouble(); } + + public static byte[] duplicate(byte[] array) { + if (array == null) { + return null; + } + byte[] newArray = new byte[array.length]; + System.arraycopy(array, 0, newArray, 0, array.length); + return newArray; + } } diff --git a/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java b/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java index 71c549683d7..e01fda981ad 100644 --- a/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java +++ b/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java @@ -61,6 +61,7 @@ import java.util.function.Supplier; import static java.util.Collections.singletonList; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -620,14 +621,29 @@ public final class MessageTest { message.setMyNullableString("notNull"); message.setMyInt16((short) 3); message.setMyString("test string"); + SimpleExampleMessageData duplicate = message.duplicate(); + assertEquals(duplicate, message); + assertEquals(message, duplicate); + duplicate.setMyTaggedIntArray(Collections.singletonList(123)); + assertFalse(duplicate.equals(message)); + assertFalse(message.equals(duplicate)); testAllMessageRoundTripsFromVersion((short) 2, message); } private void testAllMessageRoundTrips(Message message) throws Exception { + testDuplication(message); testAllMessageRoundTripsFromVersion(message.lowestSupportedVersion(), message); } + private void testDuplication(Message message) { + Message duplicate = message.duplicate(); + assertEquals(duplicate, message); + assertEquals(message, duplicate); + assertEquals(duplicate.hashCode(), message.hashCode()); + assertEquals(message.hashCode(), duplicate.hashCode()); + } + private void testAllMessageRoundTripsBeforeVersion(short beforeVersion, Message message, Message expected) throws Exception { testAllMessageRoundTripsBetweenVersions((short) 0, beforeVersion, message, expected); } diff --git a/generator/src/main/java/org/apache/kafka/message/FieldType.java b/generator/src/main/java/org/apache/kafka/message/FieldType.java index af5a5a70809..5e0a4c7a745 100644 --- a/generator/src/main/java/org/apache/kafka/message/FieldType.java +++ b/generator/src/main/java/org/apache/kafka/message/FieldType.java @@ -20,7 +20,7 @@ package org.apache.kafka.message; import java.util.Optional; public interface FieldType { - String STRUCT_PREFIX = "[]"; + String ARRAY_PREFIX = "[]"; final class BoolFieldType implements FieldType { static final BoolFieldType INSTANCE = new BoolFieldType(); @@ -268,8 +268,8 @@ public interface FieldType { case BytesFieldType.NAME: return BytesFieldType.INSTANCE; default: - if (string.startsWith(STRUCT_PREFIX)) { - String elementTypeString = string.substring(STRUCT_PREFIX.length()); + if (string.startsWith(ARRAY_PREFIX)) { + String elementTypeString = string.substring(ARRAY_PREFIX.length()); if (elementTypeString.length() == 0) { throw new RuntimeException("Can't parse array type " + string + ". No element type found."); diff --git a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java index 2c9cc0f3004..90c8bbe3f07 100644 --- a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java +++ b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java @@ -112,6 +112,8 @@ public final class MessageDataGenerator { buffer.printf("%n"); generateClassHashCode(struct, isSetElement); buffer.printf("%n"); + generateClassDuplicate(className, struct); + buffer.printf("%n"); generateClassToString(className, struct); generateFieldAccessors(struct, isSetElement); buffer.printf("%n"); @@ -193,10 +195,16 @@ public final class MessageDataGenerator { collectionType(className), className); buffer.incrementIndent(); generateHashSetZeroArgConstructor(className); + buffer.printf("%n"); generateHashSetSizeArgConstructor(className); + buffer.printf("%n"); generateHashSetIteratorConstructor(className); + buffer.printf("%n"); generateHashSetFindMethod(className, struct); + buffer.printf("%n"); generateHashSetFindAllMethod(className, struct); + buffer.printf("%n"); + generateCollectionDuplicateMethod(className, struct); buffer.decrementIndent(); buffer.printf("}%n"); } @@ -207,7 +215,6 @@ public final class MessageDataGenerator { buffer.printf("super();%n"); buffer.decrementIndent(); buffer.printf("}%n"); - buffer.printf("%n"); } private void generateHashSetSizeArgConstructor(String className) { @@ -216,7 +223,6 @@ public final class MessageDataGenerator { buffer.printf("super(expectedNumElements);%n"); buffer.decrementIndent(); buffer.printf("}%n"); - buffer.printf("%n"); } private void generateHashSetIteratorConstructor(String className) { @@ -226,7 +232,6 @@ public final class MessageDataGenerator { buffer.printf("super(iterator);%n"); buffer.decrementIndent(); buffer.printf("}%n"); - buffer.printf("%n"); } private void generateHashSetFindMethod(String className, StructSpec struct) { @@ -239,7 +244,6 @@ public final class MessageDataGenerator { buffer.printf("return find(_key);%n"); buffer.decrementIndent(); buffer.printf("}%n"); - buffer.printf("%n"); } private void generateHashSetFindAllMethod(String className, StructSpec struct) { @@ -252,7 +256,6 @@ public final class MessageDataGenerator { buffer.printf("return findAll(_key);%n"); buffer.decrementIndent(); buffer.printf("}%n"); - buffer.printf("%n"); } private void generateKeyElement(String className, StructSpec struct) { @@ -273,6 +276,22 @@ public final class MessageDataGenerator { collect(Collectors.joining(", ")); } + private void generateCollectionDuplicateMethod(String className, StructSpec struct) { + headerGenerator.addImport(MessageGenerator.LIST_CLASS); + buffer.printf("public %s duplicate() {%n", collectionType(className)); + buffer.incrementIndent(); + buffer.printf("%s _duplicate = new %s(size());%n", + collectionType(className), collectionType(className)); + buffer.printf("for (%s _element : this) {%n", className); + buffer.incrementIndent(); + buffer.printf("_duplicate.add(_element.duplicate());%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("return _duplicate;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + private void generateFieldDeclarations(StructSpec struct, boolean isSetElement) { for (FieldSpec field : struct.fields()) { generateFieldDeclaration(field); @@ -2112,6 +2131,83 @@ public final class MessageDataGenerator { } } + private void generateClassDuplicate(String className, StructSpec struct) { + buffer.printf("@Override%n"); + buffer.printf("public %s duplicate() {%n", className); + buffer.incrementIndent(); + buffer.printf("%s _duplicate = new %s();%n", className, className); + for (FieldSpec field : struct.fields()) { + generateFieldDuplicate(new Target(field, + field.camelCaseName(), + field.camelCaseName(), + input -> String.format("_duplicate.%s = %s", field.camelCaseName(), input))); + } + buffer.printf("return _duplicate;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFieldDuplicate(Target target) { + FieldSpec field = target.field(); + if ((field.type() instanceof FieldType.BoolFieldType) || + (field.type() instanceof FieldType.Int8FieldType) || + (field.type() instanceof FieldType.Int16FieldType) || + (field.type() instanceof FieldType.Int32FieldType) || + (field.type() instanceof FieldType.Int64FieldType) || + (field.type() instanceof FieldType.Float64FieldType) || + (field.type() instanceof FieldType.UUIDFieldType)) { + buffer.printf("%s;%n", target.assignmentStatement(target.sourceVariable())); + } else { + IsNullConditional cond = IsNullConditional.forName(target.sourceVariable()). + nullableVersions(target.field().nullableVersions()). + ifNull(() -> buffer.printf("%s;%n", target.assignmentStatement("null"))); + if (field.type().isBytes()) { + if (field.zeroCopy()) { + cond.ifShouldNotBeNull(() -> + buffer.printf("%s;%n", target.assignmentStatement( + String.format("%s.duplicate()", target.sourceVariable())))); + } else { + cond.ifShouldNotBeNull(() -> { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.duplicate(%s)", + target.sourceVariable()))); + }); + } + } else if (field.type().isStruct()) { + cond.ifShouldNotBeNull(() -> + buffer.printf("%s;%n", target.assignmentStatement( + String.format("%s.duplicate()", target.sourceVariable())))); + } else if (field.type().isString()) { + // Strings are immutable, so we don't need to duplicate them. + cond.ifShouldNotBeNull(() -> + buffer.printf("%s;%n", target.assignmentStatement( + target.sourceVariable()))); + } else if (field.type().isArray()) { + cond.ifShouldNotBeNull(() -> { + String newArrayName = + String.format("new%s", field.capitalizedCamelCaseName()); + buffer.printf("%s %s = new %s(%s.size());%n", + fieldConcreteJavaType(field), newArrayName, + fieldConcreteJavaType(field), target.sourceVariable()); + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + buffer.printf("for (%s _element : %s) {%n", + getBoxedJavaType(arrayType.elementType()), target.sourceVariable()); + buffer.incrementIndent(); + generateFieldDuplicate(target.arrayElementTarget(input -> + String.format("%s.add(%s)", newArrayName, input))); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new%s", field.capitalizedCamelCaseName()))); + }); + } else { + throw new RuntimeException("Unhandled field type " + field.type()); + } + cond.generate(buffer); + } + } + private void generateClassToString(String className, StructSpec struct) { buffer.printf("@Override%n"); buffer.printf("public String toString() {%n"); @@ -2310,13 +2406,7 @@ public final class MessageDataGenerator { field.name() + ". The only valid default for an array field " + "is the empty array or null."); } - FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); - if (structRegistry.isStructArrayWithKeys(field)) { - return "new " + collectionType(arrayType.elementType().toString()) + "(0)"; - } else { - headerGenerator.addImport(MessageGenerator.ARRAYLIST_CLASS); - return "new ArrayList<" + getBoxedJavaType(arrayType.elementType()) + ">()"; - } + return String.format("new %s(0)", fieldConcreteJavaType(field)); } else { throw new RuntimeException("Unsupported field type " + field.type()); }