diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java index bfdae6b37f8..0981291efd1 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java @@ -110,6 +110,9 @@ public abstract class AbstractTask { public abstract void close(); + public abstract void initTopology(); + public abstract void closeTopology(); + public abstract void commitOffsets(); /** diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java index 8c6078a10bc..6496b8852e4 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java @@ -170,4 +170,11 @@ public class PartitionGroup { queuesByTime.clear(); partitionQueues.clear(); } + + public void clear() { + queuesByTime.clear(); + for (RecordQueue queue : partitionQueues.values()) { + queue.clear(); + } + } } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java index 5199c969a6f..44ef1460784 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java @@ -176,4 +176,11 @@ public class RecordQueue { public long timestamp() { return partitionTime; } + + /** + * Clear the fifo queue of its elements + */ + public void clear() { + fifoQueue.clear(); + } } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java index ac8b0ffa680..4437a1955e7 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java @@ -102,6 +102,16 @@ public class StandbyTask extends AbstractTask { //no-op } + @Override + public void initTopology() { + //no-op + } + + @Override + public void closeTopology() { + //no-op + } + @Override public void commitOffsets() { // no-op diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java index b993054ab78..0733e5a2b10 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java @@ -52,6 +52,7 @@ public class StreamTask extends AbstractTask implements Punctuator { private final PartitionGroup partitionGroup; private final PartitionGroup.RecordInfo recordInfo = new PartitionGroup.RecordInfo(); private final PunctuationQueue punctuationQueue; + private final Map partitionQueues; private final Map consumedOffsets; private final RecordCollector recordCollector; @@ -93,7 +94,7 @@ public class StreamTask extends AbstractTask implements Punctuator { // create queues for each assigned partition and associate them // to corresponding source nodes in the processor topology - Map partitionQueues = new HashMap<>(); + partitionQueues = new HashMap<>(); for (TopicPartition partition : partitions) { SourceNode source = topology.source(partition.topic()); @@ -119,16 +120,7 @@ public class StreamTask extends AbstractTask implements Punctuator { log.info("{} Initializing state stores", logPrefix); initializeStateStores(); - // initialize the task by initializing all its processor nodes in the topology - log.info("{} Initializing processor nodes of the topology", logPrefix); - for (ProcessorNode node : this.topology.processors()) { - this.currNode = node; - try { - node.init(this.processorContext); - } finally { - this.currNode = null; - } - } + initTopology(); ((ProcessorContextImpl) this.processorContext).initialized(); } @@ -328,15 +320,24 @@ public class StreamTask extends AbstractTask implements Punctuator { punctuationQueue.schedule(new PunctuationSchedule(currNode, interval)); } - /** - * @throws RuntimeException if an error happens during closing of processor nodes - */ @Override - public void close() { - log.debug("{} Closing processor topology", logPrefix); + public void initTopology() { + // initialize the task by initializing all its processor nodes in the topology + log.info("{} Initializing processor nodes of the topology", logPrefix); + for (ProcessorNode node : this.topology.processors()) { + this.currNode = node; + try { + node.init(this.processorContext); + } finally { + this.currNode = null; + } + } + } - this.partitionGroup.close(); - this.consumedOffsets.clear(); + @Override + public void closeTopology() { + + this.partitionGroup.clear(); // close the processors // make sure close() is called for each node even when there is a RuntimeException @@ -357,6 +358,18 @@ public class StreamTask extends AbstractTask implements Punctuator { } } + /** + * @throws RuntimeException if an error happens during closing of processor nodes + */ + @Override + public void close() { + log.debug("{} Closing processor topology", logPrefix); + + this.partitionGroup.close(); + this.consumedOffsets.clear(); + closeTopology(); + } + @Override protected Map recordCollectorOffsets() { return recordCollector.offsets(); diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index 7a04339c8db..05d712671f3 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -93,6 +93,8 @@ public class StreamThread extends Thread { private final Map activeTasksByPartition; private final Map standbyTasksByPartition; private final Set prevTasks; + private final Map suspendedTasks; + private final Map suspendedStandbyTasks; private final Time time; private final long pollTimeMs; private final long cleanTimeMs; @@ -119,7 +121,6 @@ public class StreamThread extends Thread { try { log.info("stream-thread [{}] New partitions [{}] assigned at the end of consumer rebalance.", StreamThread.this.getName(), assignment); - addStreamTasks(assignment); addStandbyTasks(); lastCleanMs = time.milliseconds(); // start the cleaning cycle @@ -136,16 +137,14 @@ public class StreamThread extends Thread { try { log.info("stream-thread [{}] partitions [{}] revoked at the beginning of consumer rebalance.", StreamThread.this.getName(), assignment); - initialized.set(false); lastCleanMs = Long.MAX_VALUE; // stop the cleaning cycle until partitions are assigned - shutdownTasksAndState(true); + // suspend active tasks + suspendTasksAndState(true); } catch (Throwable t) { rebalanceException = t; throw t; } finally { - // TODO: right now upon partition revocation, we always remove all the tasks; - // this behavior can be optimized to only remove affected tasks in the future streamsMetadataState.onChange(Collections.>emptyMap(), partitionAssignor.clusterMetadata()); removeStreamTasks(); removeStandbyTasks(); @@ -206,6 +205,8 @@ public class StreamThread extends Thread { this.activeTasksByPartition = new HashMap<>(); this.standbyTasksByPartition = new HashMap<>(); this.prevTasks = new HashSet<>(); + this.suspendedTasks = new HashMap<>(); + this.suspendedStandbyTasks = new HashMap<>(); // standby ktables this.standbyRecords = new HashMap<>(); @@ -291,7 +292,23 @@ public class StreamThread extends Thread { log.info("{} Stream thread shutdown complete", logPrefix); } + private void unAssignChangeLogPartitions(final boolean rethrowExceptions) { + try { + // un-assign the change log partitions + restoreConsumer.assign(Collections.emptyList()); + } catch (Exception e) { + log.error("{} Failed to un-assign change log partitions: ", logPrefix, e); + if (rethrowExceptions) { + throw e; + } + } + } + + private void shutdownTasksAndState(final boolean rethrowExceptions) { + log.debug("{} shutdownTasksAndState: shutting down all active tasks [{}] and standby tasks [{}]", logPrefix, + activeTasks.keySet(), standbyTasks.keySet()); + // Commit first as there may be cached records that have not been flushed yet. commitOffsets(rethrowExceptions); // Close all processors in topology order @@ -302,15 +319,33 @@ public class StreamThread extends Thread { producer.flush(); // Close all task state managers closeAllStateManagers(rethrowExceptions); - try { - // un-assign the change log partitions - restoreConsumer.assign(Collections.emptyList()); - } catch (Exception e) { - log.error("{} Failed to un-assign change log partitions: ", logPrefix, e); - if (rethrowExceptions) { - throw e; - } - } + // remove the changelog partitions from restore consumer + unAssignChangeLogPartitions(rethrowExceptions); + } + + + /** + * Similar to shutdownTasksAndState, however does not close the task managers, + * in the hope that soon the tasks will be assigned again + * @param rethrowExceptions + */ + private void suspendTasksAndState(final boolean rethrowExceptions) { + log.debug("{} suspendTasksAndState: suspending all active tasks [{}] and standby tasks [{}]", logPrefix, + activeTasks.keySet(), standbyTasks.keySet()); + + // Commit first as there may be cached records that have not been flushed yet. + commitOffsets(rethrowExceptions); + // Close all topology nodes + closeAllTasksTopologies(); + // flush state + flushAllState(rethrowExceptions); + // flush out any extra data sent during close + producer.flush(); + // remove the changelog partitions from restore consumer + unAssignChangeLogPartitions(rethrowExceptions); + + updateSuspendedTasks(); + } interface AbstractTaskAction { @@ -632,6 +667,27 @@ public class StreamThread extends Thread { return new StreamTask(id, applicationId, partitions, topology, consumer, producer, restoreConsumer, config, sensors, stateDirectory, cache); } + private StreamTask findMatchingSuspendedTask(final TaskId taskId, final Set partitions) { + if (suspendedTasks.containsKey(taskId)) { + final StreamTask task = suspendedTasks.get(taskId); + if (task.partitions.equals(partitions)) { + return task; + } + } + return null; + } + + private StandbyTask findMatchingSuspendedStandbyTask(final TaskId taskId, final Set partitions) { + if (suspendedStandbyTasks.containsKey(taskId)) { + final StandbyTask task = suspendedStandbyTasks.get(taskId); + if (task.partitions.equals(partitions)) { + return task; + } + } + return null; + } + + private void addStreamTasks(Collection assignment) { if (partitionAssignor == null) throw new IllegalStateException(logPrefix + " Partition assignor has not been initialized while adding stream tasks: this should not happen."); @@ -643,7 +699,15 @@ public class StreamThread extends Thread { if (assignment.containsAll(partitions)) { try { - StreamTask task = createStreamTask(taskId, partitions); + StreamTask task = findMatchingSuspendedTask(taskId, partitions); + if (task != null) { + log.debug("{} recycling old task {}", logPrefix, taskId); + suspendedTasks.remove(taskId); + task.initTopology(); + } else { + log.debug("{} creating new task {}", logPrefix, taskId); + task = createStreamTask(taskId, partitions); + } activeTasks.put(taskId, task); for (TopicPartition partition : partitions) @@ -656,6 +720,9 @@ public class StreamThread extends Thread { log.warn("{} Task {} owned partitions {} are not contained in the assignment {}", logPrefix, taskId, partitions, assignment); } } + + // finally destroy any remaining suspended tasks + removeSuspendedTasks(); } private StandbyTask createStandbyTask(TaskId id, Collection partitions) { @@ -682,7 +749,16 @@ public class StreamThread extends Thread { for (Map.Entry> entry : partitionAssignor.standbyTasks().entrySet()) { TaskId taskId = entry.getKey(); Set partitions = entry.getValue(); - StandbyTask task = createStandbyTask(taskId, partitions); + StandbyTask task = findMatchingSuspendedStandbyTask(taskId, partitions); + + if (task != null) { + log.debug("{} recycling old standby task {}", logPrefix, taskId); + suspendedStandbyTasks.remove(taskId); + task.initTopology(); + } else { + log.debug("{} creating new standby task {}", logPrefix, taskId); + task = createStandbyTask(taskId, partitions); + } if (task != null) { standbyTasks.put(taskId, task); for (TopicPartition partition : partitions) { @@ -696,6 +772,8 @@ public class StreamThread extends Thread { checkpointedOffsets.putAll(task.checkpointedOffsets()); } } + // finally destroy any remaining suspended tasks + removeSuspendedStandbyTasks(); restoreConsumer.assign(new ArrayList<>(checkpointedOffsets.keySet())); @@ -710,6 +788,13 @@ public class StreamThread extends Thread { } } + private void updateSuspendedTasks() { + log.info("{} Updating suspended tasks to contain active tasks [{}]", logPrefix, activeTasks.keySet()); + suspendedTasks.clear(); + suspendedTasks.putAll(activeTasks); + suspendedStandbyTasks.putAll(standbyTasks); + } + private void removeStreamTasks() { log.info("{} Removing all active tasks [{}]", logPrefix, activeTasks.keySet()); @@ -733,6 +818,40 @@ public class StreamThread extends Thread { standbyRecords.clear(); } + private void removeSuspendedTasks() { + log.info("{} Removing all suspended tasks [{}]", logPrefix, suspendedTasks.keySet()); + try { + // Close task and state manager + for (final AbstractTask task : suspendedTasks.values()) { + task.close(); + task.flushState(); + task.closeStateManager(); + // flush out any extra data sent during close + producer.flush(); + } + suspendedTasks.clear(); + } catch (Exception e) { + log.error("{} Failed to remove suspended tasks: ", logPrefix, e); + } + } + + private void removeSuspendedStandbyTasks() { + log.info("{} Removing all suspended standby tasks [{}]", logPrefix, suspendedStandbyTasks.keySet()); + try { + // Close task and state manager + for (final AbstractTask task : suspendedStandbyTasks.values()) { + task.close(); + task.flushState(); + task.closeStateManager(); + // flush out any extra data sent during close + producer.flush(); + } + suspendedStandbyTasks.clear(); + } catch (Exception e) { + log.error("{} Failed to remove suspended tasks: ", logPrefix, e); + } + } + private void closeAllTasks() { performOnAllTasks(new AbstractTaskAction() { @Override @@ -744,6 +863,17 @@ public class StreamThread extends Thread { }, "close", false); } + private void closeAllTasksTopologies() { + performOnAllTasks(new AbstractTaskAction() { + @Override + public void apply(final AbstractTask task) { + log.info("{} Closing a task's topology {}", StreamThread.this.logPrefix, task.id()); + task.closeTopology(); + sensors.taskDestructionSensor.record(); + } + }, "close", false); + } + /** * Produces a string representation contain useful information about a StreamThread. * This is useful in debugging scenarios. diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationDedupIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationDedupIntegrationTest.java index ab08dbe0f59..d672fd06e02 100644 --- a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationDedupIntegrationTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationDedupIntegrationTest.java @@ -13,6 +13,7 @@ package org.apache.kafka.streams.integration; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.LongDeserializer; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.kafka.common.serialization.StringSerializer; @@ -43,6 +44,8 @@ import java.util.List; import java.util.Properties; import java.util.concurrent.ExecutionException; +import kafka.utils.MockTime; + import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; @@ -56,6 +59,7 @@ public class KStreamAggregationDedupIntegrationTest { public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + private final MockTime mockTime = CLUSTER.time; private static volatile int testNo = 0; private KStreamBuilder builder; private Properties streamsConfiguration; @@ -205,6 +209,44 @@ public class KStreamAggregationDedupIntegrationTest { )); } + @Test + public void shouldGroupByKey() throws Exception { + final long timestamp = mockTime.milliseconds(); + produceMessages(timestamp); + produceMessages(timestamp); + + stream.groupByKey(Serdes.Integer(), Serdes.String()) + .count(TimeWindows.of(500L), "count-windows") + .toStream(new KeyValueMapper, Long, String>() { + @Override + public String apply(final Windowed windowedKey, final Long value) { + return windowedKey.key() + "@" + windowedKey.window().start(); + } + }).to(Serdes.String(), Serdes.Long(), outputTopic); + + startStreams(); + + final List> results = receiveMessages( + new StringDeserializer(), + new LongDeserializer() + , 5); + Collections.sort(results, new Comparator>() { + @Override + public int compare(final KeyValue o1, final KeyValue o2) { + return KStreamAggregationDedupIntegrationTest.compare(o1, o2); + } + }); + + final long window = timestamp / 500 * 500; + assertThat(results, is(Arrays.asList( + KeyValue.pair("1@" + window, 2L), + KeyValue.pair("2@" + window, 2L), + KeyValue.pair("3@" + window, 2L), + KeyValue.pair("4@" + window, 2L), + KeyValue.pair("5@" + window, 2L) + ))); + + } private void produceMessages(long timestamp) @@ -261,4 +303,6 @@ public class KStreamAggregationDedupIntegrationTest { } + + } diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationIntegrationTest.java index e5560c1b62f..bdbc1c8ffdd 100644 --- a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationIntegrationTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationIntegrationTest.java @@ -43,10 +43,6 @@ import org.junit.After; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.junit.runners.Parameterized.Parameter; -import org.junit.runners.Parameterized.Parameters; import java.io.IOException; import java.util.Arrays; @@ -59,7 +55,6 @@ import java.util.concurrent.ExecutionException; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; -@RunWith(Parameterized.class) public class KStreamAggregationIntegrationTest { private static final int NUM_BROKERS = 1; @@ -80,15 +75,6 @@ public class KStreamAggregationIntegrationTest { private Aggregator aggregator; private KStream stream; - @Parameter - public long cacheSizeBytes; - - //Single parameter, use Object[] - @Parameters - public static Object[] data() { - return new Object[] {0, 10 * 1024 * 1024L}; - } - @Before public void before() { testNo++; @@ -102,8 +88,7 @@ public class KStreamAggregationIntegrationTest { streamsConfiguration.put(StreamsConfig.ZOOKEEPER_CONNECT_CONFIG, CLUSTER.zKConnectString()); streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); - streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1); - streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, cacheSizeBytes); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); final KeyValueMapper mapper = MockKeyValueMapper.SelectValueMapper(); stream = builder.stream(Serdes.Integer(), Serdes.String(), streamOneInput); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java index 7cd0b8c9cb8..b1c19ba060a 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java @@ -83,6 +83,16 @@ public class AbstractTaskTest { } + @Override + public void initTopology() { + + } + + @Override + public void closeTopology() { + + } + @Override public void commitOffsets() { // do nothing diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 2f252e9fa60..e3aaab89156 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -154,6 +154,7 @@ public class StreamThreadTest { } } + @SuppressWarnings("unchecked") @Test public void testPartitionAssignmentChange() throws Exception { @@ -200,6 +201,8 @@ public class StreamThreadTest { expectedGroup2 = new HashSet<>(Arrays.asList(t1p2)); rebalanceListener.onPartitionsRevoked(revokedPartitions); + assertFalse(thread.tasks().containsKey(task1)); + assertEquals(0, thread.tasks().size()); rebalanceListener.onPartitionsAssigned(assignedPartitions); assertTrue(thread.tasks().containsKey(task2)); @@ -248,6 +251,20 @@ public class StreamThreadTest { assertEquals(expectedGroup2, thread.tasks().get(task4).partitions()); assertEquals(2, thread.tasks().size()); + revokedPartitions = assignedPartitions; + assignedPartitions = Arrays.asList(t1p1, t2p1, t3p1); + expectedGroup1 = new HashSet<>(Arrays.asList(t1p1)); + expectedGroup2 = new HashSet<>(Arrays.asList(t2p1, t3p1)); + + rebalanceListener.onPartitionsRevoked(revokedPartitions); + rebalanceListener.onPartitionsAssigned(assignedPartitions); + + assertTrue(thread.tasks().containsKey(task1)); + assertTrue(thread.tasks().containsKey(task4)); + assertEquals(expectedGroup1, thread.tasks().get(task1).partitions()); + assertEquals(expectedGroup2, thread.tasks().get(task4).partitions()); + assertEquals(2, thread.tasks().size()); + revokedPartitions = assignedPartitions; assignedPartitions = Collections.emptyList(); @@ -257,6 +274,7 @@ public class StreamThreadTest { assertTrue(thread.tasks().isEmpty()); } + @Test public void testMaybeClean() throws Exception { File baseDir = Files.createTempDirectory("test").toFile();