diff --git a/clients/src/main/java/org/apache/kafka/common/network/MultiSend.java b/clients/src/main/java/org/apache/kafka/common/network/MultiSend.java index f77ff971e80..6b663609305 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/MultiSend.java +++ b/clients/src/main/java/org/apache/kafka/common/network/MultiSend.java @@ -22,32 +22,35 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.nio.channels.GatheringByteChannel; -import java.util.Iterator; -import java.util.List; +import java.util.Queue; /** * A set of composite sends, sent one after another */ - public class MultiSend implements Send { - private static final Logger log = LoggerFactory.getLogger(MultiSend.class); private final String dest; - private final Iterator sendsIterator; + private final Queue sendQueue; private final long size; private long totalWritten = 0; private Send current; - public MultiSend(String dest, List sends) { + /** + * Construct a MultiSend for the given destination from a queue of Send objects. The queue will be + * consumed as the MultiSend progresses (on completion, it will be empty). + */ + public MultiSend(String dest, Queue sends) { this.dest = dest; - this.sendsIterator = sends.iterator(); - nextSendOrDone(); + this.sendQueue = sends; + long size = 0; for (Send send : sends) size += send.size(); this.size = size; + + this.current = sendQueue.poll(); } @Override @@ -65,6 +68,15 @@ public class MultiSend implements Send { return current == null; } + // Visible for testing + int numResidentSends() { + int count = 0; + if (current != null) + count += 1; + count += sendQueue.size(); + return count; + } + @Override public long writeTo(GatheringByteChannel channel) throws IOException { if (completed()) @@ -77,7 +89,7 @@ public class MultiSend implements Send { totalWrittenPerCall += written; sendComplete = current.completed(); if (sendComplete) - nextSendOrDone(); + current = sendQueue.poll(); } while (!completed() && sendComplete); totalWritten += totalWrittenPerCall; @@ -91,10 +103,4 @@ public class MultiSend implements Send { return totalWrittenPerCall; } - private void nextSendOrDone() { - if (sendsIterator.hasNext()) - current = sendsIterator.next(); - else - current = null; - } } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java index 98c6be333e1..ed0f5a347dd 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java @@ -29,12 +29,14 @@ import org.apache.kafka.common.protocol.types.Struct; import org.apache.kafka.common.record.Records; import java.nio.ByteBuffer; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Queue; import static org.apache.kafka.common.protocol.CommonFields.ERROR_CODE; import static org.apache.kafka.common.protocol.CommonFields.PARTITION_ID; @@ -359,7 +361,7 @@ public class FetchResponse extends AbstractResponse { responseHeaderStruct.writeTo(buffer); buffer.rewind(); - List sends = new ArrayList<>(); + Queue sends = new ArrayDeque<>(); sends.add(new ByteBufferSend(dest, buffer)); addResponseData(responseBodyStruct, throttleTimeMs, dest, sends); return new MultiSend(dest, sends); @@ -393,7 +395,7 @@ public class FetchResponse extends AbstractResponse { return new FetchResponse(ApiKeys.FETCH.responseSchema(version).read(buffer)); } - private static void addResponseData(Struct struct, int throttleTimeMs, String dest, List sends) { + private static void addResponseData(Struct struct, int throttleTimeMs, String dest, Queue sends) { Object[] allTopicData = struct.getArray(RESPONSES_KEY_NAME); if (struct.hasField(ERROR_CODE)) { @@ -421,7 +423,7 @@ public class FetchResponse extends AbstractResponse { addTopicData(dest, sends, (Struct) topicData); } - private static void addTopicData(String dest, List sends, Struct topicData) { + private static void addTopicData(String dest, Queue sends, Struct topicData) { String topic = topicData.get(TOPIC_NAME); Object[] allPartitionData = topicData.getArray(PARTITIONS_KEY_NAME); @@ -436,7 +438,7 @@ public class FetchResponse extends AbstractResponse { addPartitionData(dest, sends, (Struct) partitionData); } - private static void addPartitionData(String dest, List sends, Struct partitionData) { + private static void addPartitionData(String dest, Queue sends, Struct partitionData) { Struct header = partitionData.getStruct(PARTITION_HEADER_KEY_NAME); Records records = partitionData.getRecords(RECORD_SET_KEY_NAME); diff --git a/clients/src/test/java/org/apache/kafka/common/network/MultiSendTest.java b/clients/src/test/java/org/apache/kafka/common/network/MultiSendTest.java new file mode 100644 index 00000000000..d2b2ef6c5d2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/MultiSendTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedList; +import java.util.Queue; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class MultiSendTest { + + @Test + public void testSendsFreedAfterWriting() throws IOException { + String dest = "1"; + int numChunks = 4; + int chunkSize = 32; + int totalSize = numChunks * chunkSize; + + Queue sends = new LinkedList<>(); + ByteBuffer[] chunks = new ByteBuffer[numChunks]; + + for (int i = 0; i < numChunks; i++) { + ByteBuffer buffer = ByteBuffer.wrap(TestUtils.randomBytes(chunkSize)); + chunks[i] = buffer; + sends.add(new ByteBufferSend(dest, buffer)); + } + + MultiSend send = new MultiSend(dest, sends); + assertEquals(totalSize, send.size()); + + for (int i = 0; i < numChunks; i++) { + assertEquals(numChunks - i, send.numResidentSends()); + NonOverflowingByteBufferChannel out = new NonOverflowingByteBufferChannel(chunkSize); + send.writeTo(out); + out.close(); + assertEquals(chunks[i], out.buffer()); + } + + assertEquals(0, send.numResidentSends()); + assertTrue(send.completed()); + } + + private static class NonOverflowingByteBufferChannel extends org.apache.kafka.common.requests.ByteBufferChannel { + + private NonOverflowingByteBufferChannel(long size) { + super(size); + } + + @Override + public long write(ByteBuffer[] srcs) throws IOException { + // Instead of overflowing, this channel refuses additional writes once the buffer is full, + // which allows us to test the MultiSend behavior on a per-send basis. + if (!buffer().hasRemaining()) + return 0; + return super.write(srcs); + } + } + +}