diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java index fdce0644c72..a2ba480eaad 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java @@ -221,7 +221,10 @@ public class MockConsumer implements Consumer { @Override public OffsetAndMetadata committed(TopicPartition partition) { ensureNotClosed(); - return subscriptions.committed(partition); + if (subscriptions.isAssigned(partition)) { + return subscriptions.committed(partition); + } + return new OffsetAndMetadata(0); } @Override diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index cfbd3a0f246..8b7aaeacf81 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -219,7 +219,6 @@ public class StreamThread extends Thread { private ThreadCache cache; private final TaskCreator taskCreator = new TaskCreator(); - private final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator(); final ConsumerRebalanceListener rebalanceListener = new ConsumerRebalanceListener() { @Override @@ -892,18 +891,7 @@ public class StreamThread extends Thread { newStandbyTasks.put(taskId, partitions); } - if (task != null) { - standbyTasks.put(taskId, task); - for (TopicPartition partition : partitions) { - standbyTasksByPartition.put(partition, task); - } - // collect checked pointed offsets to position the restore consumer - // this include all partitions from which we restore states - for (TopicPartition partition : task.checkpointedOffsets().keySet()) { - standbyTasksByPartition.put(partition, task); - } - checkpointedOffsets.putAll(task.checkpointedOffsets()); - } + updateStandByTaskMaps(checkpointedOffsets, taskId, partitions, task); } // destroy any remaining suspended tasks @@ -911,7 +899,7 @@ public class StreamThread extends Thread { // create all newly assigned standby tasks (guard against race condition with other thread via backoff and retry) // -> other thread will call removeSuspendedStandbyTasks(); eventually - standbyTaskCreator.retryWithBackoff(newStandbyTasks); + new StandbyTaskCreator(checkpointedOffsets).retryWithBackoff(newStandbyTasks); restoreConsumer.assign(new ArrayList<>(checkpointedOffsets.keySet())); @@ -926,6 +914,21 @@ public class StreamThread extends Thread { } } + private void updateStandByTaskMaps(final Map checkpointedOffsets, final TaskId taskId, final Set partitions, final StandbyTask task) { + if (task != null) { + standbyTasks.put(taskId, task); + for (TopicPartition partition : partitions) { + standbyTasksByPartition.put(partition, task); + } + // collect checked pointed offsets to position the restore consumer + // this include all partitions from which we restore states + for (TopicPartition partition : task.checkpointedOffsets().keySet()) { + standbyTasksByPartition.put(partition, task); + } + checkpointedOffsets.putAll(task.checkpointedOffsets()); + } + } + private void updateSuspendedTasks() { log.info("{} Updating suspended tasks to contain active tasks [{}]", logPrefix, activeTasks.keySet()); suspendedTasks.clear(); @@ -1209,11 +1212,11 @@ public class StreamThread extends Thread { } } - abstract void createTask(final TaskId id, final Collection partitions); + abstract void createTask(final TaskId id, final Set partitions); } class TaskCreator extends AbstractTaskCreator { - void createTask(final TaskId taskId, final Collection partitions) { + void createTask(final TaskId taskId, final Set partitions) { log.debug("{} creating new task {}", logPrefix, taskId); final StreamTask task = createStreamTask(taskId, partitions); @@ -1226,20 +1229,16 @@ public class StreamThread extends Thread { } class StandbyTaskCreator extends AbstractTaskCreator { - void createTask(final TaskId taskId, final Collection partitions) { - log.debug("{} creating new standby task {}", logPrefix, taskId); - final StandbyTask task = createStandbyTask(taskId, partitions); + private final Map checkpointedOffsets; - standbyTasks.put(taskId, task); + StandbyTaskCreator(final Map checkpointedOffsets) { + this.checkpointedOffsets = checkpointedOffsets; + } - for (TopicPartition partition : partitions) { - standbyTasksByPartition.put(partition, task); - } - // collect checked pointed offsets to position the restore consumer - // this include all partitions from which we restore states - for (TopicPartition partition : task.checkpointedOffsets().keySet()) { - standbyTasksByPartition.put(partition, task); - } + void createTask(final TaskId taskId, final Set partitions) { + log.debug("{} creating new standby task {}", logPrefix, taskId); + final StandbyTask task = createStandbyTask(taskId, partitions); + updateStandByTaskMaps(checkpointedOffsets, taskId, partitions, task); } } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 13678a21948..1e4f883aae3 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.MockConsumer; import org.apache.kafka.clients.consumer.internals.PartitionAssignor; import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.common.Cluster; @@ -30,6 +31,7 @@ import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Utils; import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStreamBuilder; import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.TopologyBuilder; import org.apache.kafka.streams.state.Stores; @@ -593,6 +595,81 @@ public class StreamThreadTest { assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer); } + @Test + public void shouldNotNullPointerWhenStandbyTasksAssignedAndNoStateStoresForTopology() throws Exception { + final TopologyBuilder builder = new TopologyBuilder(); + builder.setApplicationId("appId") + .addSource("name", "topic") + .addSink("out", "output"); + + + final StreamsConfig config = new StreamsConfig(configProps()); + final StreamThread thread = new StreamThread(builder, config, new MockClientSupplier(), applicationId, + clientId, processId, new Metrics(), new MockTime(), new StreamsMetadataState(builder)); + + thread.partitionAssignor(new StreamPartitionAssignor() { + @Override + Map> standbyTasks() { + return Collections.singletonMap(new TaskId(0, 0), Utils.mkSet(new TopicPartition("topic", 0))); + } + }); + + thread.rebalanceListener.onPartitionsRevoked(Collections.emptyList()); + thread.rebalanceListener.onPartitionsAssigned(Collections.emptyList()); + } + + @Test + public void shouldInitializeRestoreConsumerWithOffsetsFromStandbyTasks() throws Exception { + final KStreamBuilder builder = new KStreamBuilder(); + builder.setApplicationId("appId"); + builder.stream("t1").groupByKey().count("count-one"); + builder.stream("t2").groupByKey().count("count-two"); + final StreamsConfig config = new StreamsConfig(configProps()); + final MockClientSupplier clientSupplier = new MockClientSupplier(); + + final StreamThread thread = new StreamThread(builder, config, clientSupplier, applicationId, + clientId, processId, new Metrics(), new MockTime(), new StreamsMetadataState(builder)); + + final MockConsumer restoreConsumer = clientSupplier.restoreConsumer; + restoreConsumer.updatePartitions("stream-thread-test-count-one-changelog", + Collections.singletonList(new PartitionInfo("stream-thread-test-count-one-changelog", + 0, + null, + new Node[0], + new Node[0]))); + restoreConsumer.updatePartitions("stream-thread-test-count-two-changelog", + Collections.singletonList(new PartitionInfo("stream-thread-test-count-two-changelog", + 0, + null, + new Node[0], + new Node[0]))); + + final Map> standbyTasks = new HashMap<>(); + final TopicPartition t1 = new TopicPartition("t1", 0); + standbyTasks.put(new TaskId(0, 0), Utils.mkSet(t1)); + + thread.partitionAssignor(new StreamPartitionAssignor() { + @Override + Map> standbyTasks() { + return standbyTasks; + } + }); + + thread.rebalanceListener.onPartitionsRevoked(Collections.emptyList()); + thread.rebalanceListener.onPartitionsAssigned(Collections.emptyList()); + + assertThat(restoreConsumer.assignment(), equalTo(Utils.mkSet(new TopicPartition("stream-thread-test-count-one-changelog", 0)))); + + // assign an existing standby plus a new one + standbyTasks.put(new TaskId(1, 0), Utils.mkSet(new TopicPartition("t2", 0))); + thread.rebalanceListener.onPartitionsRevoked(Collections.emptyList()); + thread.rebalanceListener.onPartitionsAssigned(Collections.emptyList()); + + assertThat(restoreConsumer.assignment(), equalTo(Utils.mkSet(new TopicPartition("stream-thread-test-count-one-changelog", 0), + new TopicPartition("stream-thread-test-count-two-changelog", 0)))); + + } + private void initPartitionGrouper(StreamsConfig config, StreamThread thread) { StreamPartitionAssignor partitionAssignor = new StreamPartitionAssignor();