Browse Source

KAFKA-15326: [10/N] Integrate processing thread (#14193)

- Introduce a new internal config flag to enable processing threads
- If enabled, create a scheduling task manager inside the normal task manager (renamings will be added on top of this), and use it from the stream thread
- All operations inside the task manager that change task state, lock the corresponding tasks if processing threads are enabled.
- Adds a new abstract class AbstractPartitionGroup. We can modify the underlying implementation depending on the synchronization requirements. PartitionGroup is the unsynchronized subclass that is going to be used by the original code path. The processing thread code path uses a trivially synchronized SynchronizedPartitionGroup that uses object monitors. Further down the road, there is the opportunity to implement a weakly synchronized alternative. The details are complex, but since the implementation is essentially a queue + some other things, it should be feasible to implement this lock-free.
- Refactorings in StreamThreadTest: Make all tests use the thread member variable and add tearDown in order avoid thread leaks and simplify debugging. Make the test parameterized on two internal flags: state updater enabled and processing threads enabled. Use JUnit's assume to disable all tests that do not apply.
Enable some integration tests with processing threads enabled.

Reviewer: Bruno Cadonna <bruno@confluent.io>
pull/13873/merge
Lucas Brutschy 1 year ago committed by GitHub
parent
commit
d144b7ee38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 27
      clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java
  2. 11
      clients/src/test/java/org/apache/kafka/clients/producer/MockProducerTest.java
  3. 7
      streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
  4. 85
      streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java
  5. 12
      streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
  6. 49
      streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
  7. 42
      streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
  8. 131
      streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
  9. 101
      streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java
  10. 106
      streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
  11. 5
      streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
  12. 2
      streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
  13. 11
      streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskExecutor.java
  14. 6
      streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java
  15. 2
      streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/TaskManager.java
  16. 2
      streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java
  17. 31
      streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
  18. 13
      streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
  19. 1
      streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
  20. 26
      streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
  21. 32
      streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
  22. 838
      streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
  23. 201
      streams/src/test/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroupTest.java
  24. 124
      streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
  25. 5
      streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
  26. 4
      streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java

27
clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java

@ -160,7 +160,8 @@ public class MockProducer<K, V> implements Producer<K, V> {
@Override @Override
public void initTransactions() { public void initTransactions() {
verifyProducerState(); verifyNotClosed();
verifyNotFenced();
if (this.transactionInitialized) { if (this.transactionInitialized) {
throw new IllegalStateException("MockProducer has already been initialized for transactions."); throw new IllegalStateException("MockProducer has already been initialized for transactions.");
} }
@ -176,7 +177,8 @@ public class MockProducer<K, V> implements Producer<K, V> {
@Override @Override
public void beginTransaction() throws ProducerFencedException { public void beginTransaction() throws ProducerFencedException {
verifyProducerState(); verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized(); verifyTransactionsInitialized();
if (this.beginTransactionException != null) { if (this.beginTransactionException != null) {
@ -205,7 +207,8 @@ public class MockProducer<K, V> implements Producer<K, V> {
public void sendOffsetsToTransaction(Map<TopicPartition, OffsetAndMetadata> offsets, public void sendOffsetsToTransaction(Map<TopicPartition, OffsetAndMetadata> offsets,
ConsumerGroupMetadata groupMetadata) throws ProducerFencedException { ConsumerGroupMetadata groupMetadata) throws ProducerFencedException {
Objects.requireNonNull(groupMetadata); Objects.requireNonNull(groupMetadata);
verifyProducerState(); verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized(); verifyTransactionsInitialized();
verifyTransactionInFlight(); verifyTransactionInFlight();
@ -224,7 +227,8 @@ public class MockProducer<K, V> implements Producer<K, V> {
@Override @Override
public void commitTransaction() throws ProducerFencedException { public void commitTransaction() throws ProducerFencedException {
verifyProducerState(); verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized(); verifyTransactionsInitialized();
verifyTransactionInFlight(); verifyTransactionInFlight();
@ -249,7 +253,8 @@ public class MockProducer<K, V> implements Producer<K, V> {
@Override @Override
public void abortTransaction() throws ProducerFencedException { public void abortTransaction() throws ProducerFencedException {
verifyProducerState(); verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized(); verifyTransactionsInitialized();
verifyTransactionInFlight(); verifyTransactionInFlight();
@ -265,10 +270,13 @@ public class MockProducer<K, V> implements Producer<K, V> {
this.transactionInFlight = false; this.transactionInFlight = false;
} }
private synchronized void verifyProducerState() { private synchronized void verifyNotClosed() {
if (this.closed) { if (this.closed) {
throw new IllegalStateException("MockProducer is already closed."); throw new IllegalStateException("MockProducer is already closed.");
} }
}
private synchronized void verifyNotFenced() {
if (this.producerFenced) { if (this.producerFenced) {
throw new ProducerFencedException("MockProducer is fenced."); throw new ProducerFencedException("MockProducer is fenced.");
} }
@ -288,7 +296,7 @@ public class MockProducer<K, V> implements Producer<K, V> {
/** /**
* Adds the record to the list of sent records. The {@link RecordMetadata} returned will be immediately satisfied. * Adds the record to the list of sent records. The {@link RecordMetadata} returned will be immediately satisfied.
* *
* @see #history() * @see #history()
*/ */
@Override @Override
@ -362,7 +370,7 @@ public class MockProducer<K, V> implements Producer<K, V> {
} }
public synchronized void flush() { public synchronized void flush() {
verifyProducerState(); verifyNotClosed();
if (this.flushException != null) { if (this.flushException != null) {
throw this.flushException; throw this.flushException;
@ -415,7 +423,8 @@ public class MockProducer<K, V> implements Producer<K, V> {
} }
public synchronized void fenceProducer() { public synchronized void fenceProducer() {
verifyProducerState(); verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized(); verifyTransactionsInitialized();
this.producerFenced = true; this.producerFenced = true;
} }

11
clients/src/test/java/org/apache/kafka/clients/producer/MockProducerTest.java

@ -40,6 +40,7 @@ import java.util.concurrent.Future;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
@ -702,7 +703,15 @@ public class MockProducerTest {
producer.close(); producer.close();
assertThrows(IllegalStateException.class, producer::flush); assertThrows(IllegalStateException.class, producer::flush);
} }
@Test
public void shouldNotThrowOnFlushProducerIfProducerIsFenced() {
buildMockProducer(true);
producer.initTransactions();
producer.fenceProducer();
assertDoesNotThrow(producer::flush);
}
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void shouldThrowClassCastException() { public void shouldThrowClassCastException() {

7
streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java

@ -1200,6 +1200,13 @@ public class StreamsConfig extends AbstractConfig {
public static boolean getStateUpdaterEnabled(final Map<String, Object> configs) { public static boolean getStateUpdaterEnabled(final Map<String, Object> configs) {
return InternalConfig.getBoolean(configs, InternalConfig.STATE_UPDATER_ENABLED, true); return InternalConfig.getBoolean(configs, InternalConfig.STATE_UPDATER_ENABLED, true);
} }
// Private API to enable processing threads (i.e. polling is decoupled from processing)
public static final String PROCESSING_THREADS_ENABLED = "__processing.threads.enabled__";
public static boolean getProcessingThreadsEnabled(final Map<String, Object> configs) {
return InternalConfig.getBoolean(configs, InternalConfig.PROCESSING_THREADS_ENABLED, false);
}
public static boolean getBoolean(final Map<String, Object> configs, final String key, final boolean defaultValue) { public static boolean getBoolean(final Map<String, Object> configs, final String key, final boolean defaultValue) {
final Object value = configs.getOrDefault(key, defaultValue); final Object value = configs.getOrDefault(key, defaultValue);

85
streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java

@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals;
import java.util.Set;
import java.util.function.Function;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
abstract class AbstractPartitionGroup {
abstract boolean readyToProcess(long wallClockTime);
// creates queues for new partitions, removes old queues, saves cached records for previously assigned partitions
abstract void updatePartitions(Set<TopicPartition> inputPartitions, Function<TopicPartition, RecordQueue> recordQueueCreator);
abstract void setPartitionTime(TopicPartition partition, long partitionTime);
/**
* Get the next record and queue
*
* @return StampedRecord
*/
abstract StampedRecord nextRecord(RecordInfo info, long wallClockTime);
/**
* Adds raw records to this partition group
*
* @param partition the partition
* @param rawRecords the raw records
* @return the queue size for the partition
*/
abstract int addRawRecords(TopicPartition partition, Iterable<ConsumerRecord<byte[], byte[]>> rawRecords);
abstract long partitionTimestamp(final TopicPartition partition);
/**
* Return the stream-time of this partition group defined as the largest timestamp seen across all partitions
*/
abstract long streamTime();
abstract Long headRecordOffset(final TopicPartition partition);
abstract int numBuffered();
abstract int numBuffered(TopicPartition tp);
abstract void clear();
abstract void updateLags();
abstract void close();
abstract Set<TopicPartition> partitions();
static class RecordInfo {
RecordQueue queue;
ProcessorNode<?, ?, ?, ?> node() {
return queue.source();
}
TopicPartition partition() {
return queue.partition();
}
RecordQueue queue() {
return queue;
}
}
}

12
streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java

@ -66,6 +66,7 @@ class ActiveTaskCreator {
private final Map<TaskId, StreamsProducer> taskProducers; private final Map<TaskId, StreamsProducer> taskProducers;
private final ProcessingMode processingMode; private final ProcessingMode processingMode;
private final boolean stateUpdaterEnabled; private final boolean stateUpdaterEnabled;
private final boolean processingThreadsEnabled;
ActiveTaskCreator(final TopologyMetadata topologyMetadata, ActiveTaskCreator(final TopologyMetadata topologyMetadata,
final StreamsConfig applicationConfig, final StreamsConfig applicationConfig,
@ -78,7 +79,9 @@ class ActiveTaskCreator {
final String threadId, final String threadId,
final UUID processId, final UUID processId,
final Logger log, final Logger log,
final boolean stateUpdaterEnabled) { final boolean stateUpdaterEnabled,
final boolean processingThreadsEnabled
) {
this.topologyMetadata = topologyMetadata; this.topologyMetadata = topologyMetadata;
this.applicationConfig = applicationConfig; this.applicationConfig = applicationConfig;
this.streamsMetrics = streamsMetrics; this.streamsMetrics = streamsMetrics;
@ -90,6 +93,7 @@ class ActiveTaskCreator {
this.threadId = threadId; this.threadId = threadId;
this.log = log; this.log = log;
this.stateUpdaterEnabled = stateUpdaterEnabled; this.stateUpdaterEnabled = stateUpdaterEnabled;
this.processingThreadsEnabled = processingThreadsEnabled;
createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics); createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
processingMode = processingMode(applicationConfig); processingMode = processingMode(applicationConfig);
@ -242,7 +246,8 @@ class ActiveTaskCreator {
standbyTask.stateMgr, standbyTask.stateMgr,
recordCollector, recordCollector,
standbyTask.processorContext, standbyTask.processorContext,
standbyTask.logContext standbyTask.logContext,
processingThreadsEnabled
); );
log.trace("Created active task {} from recycled standby task with assigned partitions {}", task.id, inputPartitions); log.trace("Created active task {} from recycled standby task with assigned partitions {}", task.id, inputPartitions);
@ -272,7 +277,8 @@ class ActiveTaskCreator {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
processingThreadsEnabled
); );
log.trace("Created active task {} with assigned partitions {}", taskId, inputPartitions); log.trace("Created active task {} with assigned partitions {}", taskId, inputPartitions);

49
streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java

@ -56,7 +56,7 @@ import java.util.function.Function;
* As a consequence of the definition, the PartitionGroup's stream-time is non-decreasing * As a consequence of the definition, the PartitionGroup's stream-time is non-decreasing
* (i.e., it increases or stays the same over time). * (i.e., it increases or stays the same over time).
*/ */
public class PartitionGroup { class PartitionGroup extends AbstractPartitionGroup {
private final Logger logger; private final Logger logger;
private final Map<TopicPartition, RecordQueue> partitionQueues; private final Map<TopicPartition, RecordQueue> partitionQueues;
@ -72,22 +72,6 @@ public class PartitionGroup {
private final Map<TopicPartition, Long> idlePartitionDeadlines = new HashMap<>(); private final Map<TopicPartition, Long> idlePartitionDeadlines = new HashMap<>();
private final Map<TopicPartition, Long> fetchedLags = new HashMap<>(); private final Map<TopicPartition, Long> fetchedLags = new HashMap<>();
static class RecordInfo {
RecordQueue queue;
ProcessorNode<?, ?, ?, ?> node() {
return queue.source();
}
TopicPartition partition() {
return queue.partition();
}
RecordQueue queue() {
return queue;
}
}
PartitionGroup(final LogContext logContext, PartitionGroup(final LogContext logContext,
final Map<TopicPartition, RecordQueue> partitionQueues, final Map<TopicPartition, RecordQueue> partitionQueues,
final Function<TopicPartition, OptionalLong> lagProvider, final Function<TopicPartition, OptionalLong> lagProvider,
@ -106,7 +90,8 @@ public class PartitionGroup {
streamTime = RecordQueue.UNKNOWN; streamTime = RecordQueue.UNKNOWN;
} }
public boolean readyToProcess(final long wallClockTime) { @Override
boolean readyToProcess(final long wallClockTime) {
if (maxTaskIdleMs == StreamsConfig.MAX_TASK_IDLE_MS_DISABLED) { if (maxTaskIdleMs == StreamsConfig.MAX_TASK_IDLE_MS_DISABLED) {
if (logger.isTraceEnabled() && !allBuffered && totalBuffered > 0) { if (logger.isTraceEnabled() && !allBuffered && totalBuffered > 0) {
final Set<TopicPartition> bufferedPartitions = new HashSet<>(); final Set<TopicPartition> bufferedPartitions = new HashSet<>();
@ -209,7 +194,7 @@ public class PartitionGroup {
} }
} }
// visible for testing @Override
long partitionTimestamp(final TopicPartition partition) { long partitionTimestamp(final TopicPartition partition) {
final RecordQueue queue = partitionQueues.get(partition); final RecordQueue queue = partitionQueues.get(partition);
if (queue == null) { if (queue == null) {
@ -219,6 +204,7 @@ public class PartitionGroup {
} }
// creates queues for new partitions, removes old queues, saves cached records for previously assigned partitions // creates queues for new partitions, removes old queues, saves cached records for previously assigned partitions
@Override
void updatePartitions(final Set<TopicPartition> inputPartitions, final Function<TopicPartition, RecordQueue> recordQueueCreator) { void updatePartitions(final Set<TopicPartition> inputPartitions, final Function<TopicPartition, RecordQueue> recordQueueCreator) {
final Set<TopicPartition> removedPartitions = new HashSet<>(); final Set<TopicPartition> removedPartitions = new HashSet<>();
final Set<TopicPartition> newInputPartitions = new HashSet<>(inputPartitions); final Set<TopicPartition> newInputPartitions = new HashSet<>(inputPartitions);
@ -241,6 +227,7 @@ public class PartitionGroup {
allBuffered = allBuffered && newInputPartitions.isEmpty(); allBuffered = allBuffered && newInputPartitions.isEmpty();
} }
@Override
void setPartitionTime(final TopicPartition partition, final long partitionTime) { void setPartitionTime(final TopicPartition partition, final long partitionTime) {
final RecordQueue queue = partitionQueues.get(partition); final RecordQueue queue = partitionQueues.get(partition);
if (queue == null) { if (queue == null) {
@ -252,11 +239,7 @@ public class PartitionGroup {
queue.setPartitionTime(partitionTime); queue.setPartitionTime(partitionTime);
} }
/** @Override
* Get the next record and queue
*
* @return StampedRecord
*/
StampedRecord nextRecord(final RecordInfo info, final long wallClockTime) { StampedRecord nextRecord(final RecordInfo info, final long wallClockTime) {
StampedRecord record = null; StampedRecord record = null;
@ -290,13 +273,7 @@ public class PartitionGroup {
return record; return record;
} }
/** @Override
* Adds raw records to this partition group
*
* @param partition the partition
* @param rawRecords the raw records
* @return the queue size for the partition
*/
int addRawRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) { int addRawRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
final RecordQueue recordQueue = partitionQueues.get(partition); final RecordQueue recordQueue = partitionQueues.get(partition);
@ -328,13 +305,12 @@ public class PartitionGroup {
return Collections.unmodifiableSet(partitionQueues.keySet()); return Collections.unmodifiableSet(partitionQueues.keySet());
} }
/** @Override
* Return the stream-time of this partition group defined as the largest timestamp seen across all partitions
*/
long streamTime() { long streamTime() {
return streamTime; return streamTime;
} }
@Override
Long headRecordOffset(final TopicPartition partition) { Long headRecordOffset(final TopicPartition partition) {
final RecordQueue recordQueue = partitionQueues.get(partition); final RecordQueue recordQueue = partitionQueues.get(partition);
@ -348,6 +324,7 @@ public class PartitionGroup {
/** /**
* @throws IllegalStateException if the record's partition does not belong to this partition group * @throws IllegalStateException if the record's partition does not belong to this partition group
*/ */
@Override
int numBuffered(final TopicPartition partition) { int numBuffered(final TopicPartition partition) {
final RecordQueue recordQueue = partitionQueues.get(partition); final RecordQueue recordQueue = partitionQueues.get(partition);
@ -358,6 +335,7 @@ public class PartitionGroup {
return recordQueue.size(); return recordQueue.size();
} }
@Override
int numBuffered() { int numBuffered() {
return totalBuffered; return totalBuffered;
} }
@ -367,6 +345,7 @@ public class PartitionGroup {
return allBuffered; return allBuffered;
} }
@Override
void clear() { void clear() {
for (final RecordQueue queue : partitionQueues.values()) { for (final RecordQueue queue : partitionQueues.values()) {
queue.clear(); queue.clear();
@ -377,12 +356,14 @@ public class PartitionGroup {
fetchedLags.clear(); fetchedLags.clear();
} }
@Override
void close() { void close() {
for (final RecordQueue queue : partitionQueues.values()) { for (final RecordQueue queue : partitionQueues.values()) {
queue.close(); queue.close();
} }
} }
@Override
void updateLags() { void updateLags() {
if (maxTaskIdleMs != StreamsConfig.MAX_TASK_IDLE_MS_DISABLED) { if (maxTaskIdleMs != StreamsConfig.MAX_TASK_IDLE_MS_DISABLED) {
for (final TopicPartition tp : partitionQueues.keySet()) { for (final TopicPartition tp : partitionQueues.keySet()) {

42
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java

@ -39,6 +39,7 @@ import org.apache.kafka.streams.processor.Punctuator;
import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.TimestampExtractor; import org.apache.kafka.streams.processor.TimestampExtractor;
import org.apache.kafka.streams.processor.api.Record; import org.apache.kafka.streams.processor.api.Record;
import org.apache.kafka.streams.processor.internals.AbstractPartitionGroup.RecordInfo;
import org.apache.kafka.streams.processor.internals.metrics.ProcessorNodeMetrics; import org.apache.kafka.streams.processor.internals.metrics.ProcessorNodeMetrics;
import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
@ -64,7 +65,7 @@ import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetric
import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
/** /**
* A StreamTask is associated with a {@link PartitionGroup}, and is assigned to a StreamThread for processing. * A StreamTask is associated with a {@link AbstractPartitionGroup}, and is assigned to a StreamThread for processing.
*/ */
public class StreamTask extends AbstractTask implements ProcessorNodePunctuator, Task { public class StreamTask extends AbstractTask implements ProcessorNodePunctuator, Task {
@ -77,9 +78,9 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
private final boolean eosEnabled; private final boolean eosEnabled;
private final int maxBufferedSize; private final int maxBufferedSize;
private final PartitionGroup partitionGroup; private final AbstractPartitionGroup partitionGroup;
private final RecordCollector recordCollector; private final RecordCollector recordCollector;
private final PartitionGroup.RecordInfo recordInfo; private final AbstractPartitionGroup.RecordInfo recordInfo;
private final Map<TopicPartition, Long> consumedOffsets; private final Map<TopicPartition, Long> consumedOffsets;
private final Map<TopicPartition, Long> committedOffsets; private final Map<TopicPartition, Long> committedOffsets;
private final Map<TopicPartition, Long> highWatermark; private final Map<TopicPartition, Long> highWatermark;
@ -111,7 +112,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
private boolean hasPendingTxCommit = false; private boolean hasPendingTxCommit = false;
private Optional<Long> timeCurrentIdlingStarted; private Optional<Long> timeCurrentIdlingStarted;
@SuppressWarnings({"rawtypes", "this-escape"}) @SuppressWarnings({"rawtypes", "this-escape", "checkstyle:ParameterNumber"})
public StreamTask(final TaskId id, public StreamTask(final TaskId id,
final Set<TopicPartition> inputPartitions, final Set<TopicPartition> inputPartitions,
final ProcessorTopology topology, final ProcessorTopology topology,
@ -124,7 +125,9 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
final ProcessorStateManager stateMgr, final ProcessorStateManager stateMgr,
final RecordCollector recordCollector, final RecordCollector recordCollector,
final InternalProcessorContext processorContext, final InternalProcessorContext processorContext,
final LogContext logContext) { final LogContext logContext,
final boolean processingThreadsEnabled
) {
super( super(
id, id,
topology, topology,
@ -181,19 +184,30 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
recordQueueCreator = new RecordQueueCreator(this.logContext, config.timestampExtractor, config.deserializationExceptionHandler); recordQueueCreator = new RecordQueueCreator(this.logContext, config.timestampExtractor, config.deserializationExceptionHandler);
recordInfo = new PartitionGroup.RecordInfo(); recordInfo = new RecordInfo();
final Sensor enforcedProcessingSensor; final Sensor enforcedProcessingSensor;
enforcedProcessingSensor = TaskMetrics.enforcedProcessingSensor(threadId, taskId, streamsMetrics); enforcedProcessingSensor = TaskMetrics.enforcedProcessingSensor(threadId, taskId, streamsMetrics);
final long maxTaskIdleMs = config.maxTaskIdleMs; final long maxTaskIdleMs = config.maxTaskIdleMs;
partitionGroup = new PartitionGroup( if (processingThreadsEnabled) {
logContext, partitionGroup = new SynchronizedPartitionGroup(new PartitionGroup(
createPartitionQueues(), logContext,
mainConsumer::currentLag, createPartitionQueues(),
TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics), mainConsumer::currentLag,
enforcedProcessingSensor, TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics),
maxTaskIdleMs enforcedProcessingSensor,
); maxTaskIdleMs
));
} else {
partitionGroup = new PartitionGroup(
logContext,
createPartitionQueues(),
mainConsumer::currentLag,
TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics),
enforcedProcessingSensor,
maxTaskIdleMs
);
}
stateMgr.registerGlobalStateStores(topology.globalStateStores()); stateMgr.registerGlobalStateStores(topology.globalStateStores());
committedOffsets = new HashMap<>(); committedOffsets = new HashMap<>();

131
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java

@ -49,6 +49,8 @@ import org.apache.kafka.streams.processor.internals.assignment.AssignorError;
import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer; import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer;
import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics; import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager.DefaultTaskExecutorCreator;
import org.apache.kafka.streams.state.internals.ThreadCache; import org.apache.kafka.streams.state.internals.ThreadCache;
import java.util.Queue; import java.util.Queue;
@ -331,6 +333,7 @@ public class StreamThread extends Thread {
private final AtomicBoolean leaveGroupRequested = new AtomicBoolean(false); private final AtomicBoolean leaveGroupRequested = new AtomicBoolean(false);
private final boolean eosEnabled; private final boolean eosEnabled;
private final boolean stateUpdaterEnabled; private final boolean stateUpdaterEnabled;
private final boolean processingThreadsEnabled;
public static StreamThread create(final TopologyMetadata topologyMetadata, public static StreamThread create(final TopologyMetadata topologyMetadata,
final StreamsConfig config, final StreamsConfig config,
@ -375,6 +378,7 @@ public class StreamThread extends Thread {
final ThreadCache cache = new ThreadCache(logContext, cacheSizeBytes, streamsMetrics); final ThreadCache cache = new ThreadCache(logContext, cacheSizeBytes, streamsMetrics);
final boolean stateUpdaterEnabled = InternalConfig.getStateUpdaterEnabled(config.originals()); final boolean stateUpdaterEnabled = InternalConfig.getStateUpdaterEnabled(config.originals());
final boolean proceessingThreadsEnabled = InternalConfig.getProcessingThreadsEnabled(config.originals());
final ActiveTaskCreator activeTaskCreator = new ActiveTaskCreator( final ActiveTaskCreator activeTaskCreator = new ActiveTaskCreator(
topologyMetadata, topologyMetadata,
config, config,
@ -387,7 +391,9 @@ public class StreamThread extends Thread {
threadId, threadId,
processId, processId,
log, log,
stateUpdaterEnabled); stateUpdaterEnabled,
proceessingThreadsEnabled
);
final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator( final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator(
topologyMetadata, topologyMetadata,
config, config,
@ -398,6 +404,15 @@ public class StreamThread extends Thread {
log, log,
stateUpdaterEnabled); stateUpdaterEnabled);
final Tasks tasks = new Tasks(new LogContext(logPrefix));
final boolean processingThreadsEnabled =
InternalConfig.getProcessingThreadsEnabled(config.originals());
final DefaultTaskManager schedulingTaskManager =
maybeCreateSchedulingTaskManager(processingThreadsEnabled, stateUpdaterEnabled, topologyMetadata, time, threadId, tasks);
final StateUpdater stateUpdater =
maybeCreateAndStartStateUpdater(stateUpdaterEnabled, streamsMetrics, config, changelogReader, topologyMetadata, time, clientId, threadIdx);
final TaskManager taskManager = new TaskManager( final TaskManager taskManager = new TaskManager(
time, time,
changelogReader, changelogReader,
@ -405,11 +420,12 @@ public class StreamThread extends Thread {
logPrefix, logPrefix,
activeTaskCreator, activeTaskCreator,
standbyTaskCreator, standbyTaskCreator,
new Tasks(new LogContext(logPrefix)), tasks,
topologyMetadata, topologyMetadata,
adminClient, adminClient,
stateDirectory, stateDirectory,
maybeCreateAndStartStateUpdater(stateUpdaterEnabled, streamsMetrics, config, changelogReader, topologyMetadata, time, clientId, threadIdx) stateUpdater,
schedulingTaskManager
); );
referenceContainer.taskManager = taskManager; referenceContainer.taskManager = taskManager;
@ -452,6 +468,31 @@ public class StreamThread extends Thread {
return streamThread.updateThreadMetadata(getSharedAdminClientId(clientId)); return streamThread.updateThreadMetadata(getSharedAdminClientId(clientId));
} }
private static DefaultTaskManager maybeCreateSchedulingTaskManager(final boolean processingThreadsEnabled,
final boolean stateUpdaterEnabled,
final TopologyMetadata topologyMetadata,
final Time time,
final String threadId,
final Tasks tasks) {
if (processingThreadsEnabled) {
if (!stateUpdaterEnabled) {
throw new IllegalStateException("Processing threads require the state updater to be enabled");
}
final DefaultTaskManager defaultTaskManager = new DefaultTaskManager(
time,
threadId,
tasks,
new DefaultTaskExecutorCreator(),
topologyMetadata.taskExecutionMetadata(),
1
);
defaultTaskManager.startTaskExecutors();
return defaultTaskManager;
}
return null;
}
private static StateUpdater maybeCreateAndStartStateUpdater(final boolean stateUpdaterEnabled, private static StateUpdater maybeCreateAndStartStateUpdater(final boolean stateUpdaterEnabled,
final StreamsMetricsImpl streamsMetrics, final StreamsMetricsImpl streamsMetrics,
final StreamsConfig streamsConfig, final StreamsConfig streamsConfig,
@ -488,7 +529,8 @@ public class StreamThread extends Thread {
final Queue<StreamsException> nonFatalExceptionsToHandle, final Queue<StreamsException> nonFatalExceptionsToHandle,
final Runnable shutdownErrorHook, final Runnable shutdownErrorHook,
final BiConsumer<Throwable, Boolean> streamsUncaughtExceptionHandler, final BiConsumer<Throwable, Boolean> streamsUncaughtExceptionHandler,
final java.util.function.Consumer<Long> cacheResizer) { final java.util.function.Consumer<Long> cacheResizer
) {
super(threadId); super(threadId);
this.stateLock = new Object(); this.stateLock = new Object();
this.adminClient = adminClient; this.adminClient = adminClient;
@ -558,6 +600,7 @@ public class StreamThread extends Thread {
this.numIterations = 1; this.numIterations = 1;
this.eosEnabled = eosEnabled(config); this.eosEnabled = eosEnabled(config);
this.stateUpdaterEnabled = InternalConfig.getStateUpdaterEnabled(config.originals()); this.stateUpdaterEnabled = InternalConfig.getStateUpdaterEnabled(config.originals());
this.processingThreadsEnabled = InternalConfig.getProcessingThreadsEnabled(config.originals());
} }
private static final class InternalConsumerConfig extends ConsumerConfig { private static final class InternalConsumerConfig extends ConsumerConfig {
@ -620,7 +663,11 @@ public class StreamThread extends Thread {
if (size != -1L) { if (size != -1L) {
cacheResizer.accept(size); cacheResizer.accept(size);
} }
runOnce(); if (processingThreadsEnabled) {
runOnceWithProcessingThreads();
} else {
runOnceWithoutProcessingThreads();
}
// Check for a scheduled rebalance but don't trigger it until the current rebalance is done // Check for a scheduled rebalance but don't trigger it until the current rebalance is done
if (!taskManager.rebalanceInProgress() && nextProbingRebalanceMs.get() < time.milliseconds()) { if (!taskManager.rebalanceInProgress() && nextProbingRebalanceMs.get() < time.milliseconds()) {
@ -765,7 +812,7 @@ public class StreamThread extends Thread {
* or if the task producer got fenced (EOS) * or if the task producer got fenced (EOS)
*/ */
// Visible for testing // Visible for testing
void runOnce() { void runOnceWithoutProcessingThreads() {
final long startMs = time.milliseconds(); final long startMs = time.milliseconds();
now = startMs; now = startMs;
@ -900,6 +947,78 @@ public class StreamThread extends Thread {
} }
} }
/**
* One iteration of a thread includes the following steps:
*
* 1. poll records from main consumer and add to buffer;
* 2. check the task manager for any exceptions to be handled
* 3. commit all tasks if necessary;
*
* @throws IllegalStateException If store gets registered after initialized is already finished
* @throws StreamsException If the store's change log does not contain the partition
* @throws TaskMigratedException If another thread wrote to the changelog topic that is currently restored
* or if committing offsets failed (non-EOS)
* or if the task producer got fenced (EOS)
*/
// Visible for testing
void runOnceWithProcessingThreads() {
final long startMs = time.milliseconds();
now = startMs;
final long pollLatency;
taskManager.resumePollingForPartitionsWithAvailableSpace();
try {
pollLatency = pollPhase();
} finally {
taskManager.updateLags();
}
// Shutdown hook could potentially be triggered and transit the thread state to PENDING_SHUTDOWN during #pollRequests().
// The task manager internal states could be uninitialized if the state transition happens during #onPartitionsAssigned().
// Should only proceed when the thread is still running after #pollRequests(), because no external state mutation
// could affect the task manager state beyond this point within #runOnce().
if (!isRunning()) {
log.debug("Thread state is already {}, skipping the run once call after poll request", state);
return;
}
long totalCommitLatency = 0L;
if (isRunning()) {
checkStateUpdater();
taskManager.maybeThrowTaskExceptionsFromProcessingThreads();
taskManager.signalTaskExecutors();
final long beforeCommitMs = now;
final int committed = maybeCommit();
final long commitLatency = Math.max(now - beforeCommitMs, 0);
totalCommitLatency += commitLatency;
if (committed > 0) {
totalCommittedSinceLastSummary += committed;
commitSensor.record(commitLatency / (double) committed, now);
if (log.isDebugEnabled()) {
log.debug("Committed all active tasks {} and standby tasks {} in {}ms",
taskManager.activeTaskIds(), taskManager.standbyTaskIds(), commitLatency);
}
}
}
now = time.milliseconds();
final long runOnceLatency = now - startMs;
pollRatioSensor.record((double) pollLatency / runOnceLatency, now);
commitRatioSensor.record((double) totalCommitLatency / runOnceLatency, now);
final boolean logProcessingSummary = now - lastLogSummaryMs > LOG_SUMMARY_INTERVAL_MS;
if (logProcessingSummary) {
log.info("Committed {} total tasks since the last update", totalCommittedSinceLastSummary);
totalCommittedSinceLastSummary = 0L;
lastLogSummaryMs = now;
}
}
private void initializeAndRestorePhase() { private void initializeAndRestorePhase() {
final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = partitions -> resetOffsets(partitions, null); final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = partitions -> resetOffsets(partitions, null);
final State stateSnapshot = state; final State stateSnapshot = state;

101
streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java

@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals;
import java.util.Set;
import java.util.function.Function;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
class SynchronizedPartitionGroup extends AbstractPartitionGroup {
private final AbstractPartitionGroup wrapped;
public SynchronizedPartitionGroup(final AbstractPartitionGroup wrapped) {
this.wrapped = wrapped;
}
@Override
synchronized boolean readyToProcess(final long wallClockTime) {
return wrapped.readyToProcess(wallClockTime);
}
@Override
synchronized void updatePartitions(final Set<TopicPartition> inputPartitions, final Function<TopicPartition, RecordQueue> recordQueueCreator) {
wrapped.updatePartitions(inputPartitions, recordQueueCreator);
}
@Override
synchronized void setPartitionTime(final TopicPartition partition, final long partitionTime) {
wrapped.setPartitionTime(partition, partitionTime);
}
@Override
synchronized StampedRecord nextRecord(final RecordInfo info, final long wallClockTime) {
return wrapped.nextRecord(info, wallClockTime);
}
@Override
synchronized int addRawRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
return wrapped.addRawRecords(partition, rawRecords);
}
@Override
synchronized long partitionTimestamp(final TopicPartition partition) {
return wrapped.partitionTimestamp(partition);
}
@Override
synchronized long streamTime() {
return wrapped.streamTime();
}
@Override
synchronized Long headRecordOffset(final TopicPartition partition) {
return wrapped.headRecordOffset(partition);
}
@Override
synchronized int numBuffered() {
return wrapped.numBuffered();
}
@Override
synchronized int numBuffered(final TopicPartition tp) {
return wrapped.numBuffered(tp);
}
@Override
synchronized void clear() {
wrapped.clear();
}
@Override
synchronized void updateLags() {
wrapped.updateLags();
}
@Override
synchronized void close() {
wrapped.close();
}
@Override
synchronized Set<TopicPartition> partitions() {
return wrapped.partitions();
}
}

106
streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java

@ -16,6 +16,7 @@
*/ */
package org.apache.kafka.streams.processor.internals; package org.apache.kafka.streams.processor.internals;
import java.util.concurrent.ExecutionException;
import org.apache.kafka.clients.admin.Admin; import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.DeleteRecordsResult; import org.apache.kafka.clients.admin.DeleteRecordsResult;
import org.apache.kafka.clients.admin.RecordsToDelete; import org.apache.kafka.clients.admin.RecordsToDelete;
@ -39,6 +40,7 @@ import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory; import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
import org.apache.kafka.streams.processor.internals.Task.State; import org.apache.kafka.streams.processor.internals.Task.State;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
import org.apache.kafka.streams.state.internals.OffsetCheckpoint; import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -97,6 +99,7 @@ public class TaskManager {
private final ActiveTaskCreator activeTaskCreator; private final ActiveTaskCreator activeTaskCreator;
private final StandbyTaskCreator standbyTaskCreator; private final StandbyTaskCreator standbyTaskCreator;
private final StateUpdater stateUpdater; private final StateUpdater stateUpdater;
private final DefaultTaskManager schedulingTaskManager;
TaskManager(final Time time, TaskManager(final Time time,
final ChangelogReader changelogReader, final ChangelogReader changelogReader,
@ -108,7 +111,9 @@ public class TaskManager {
final TopologyMetadata topologyMetadata, final TopologyMetadata topologyMetadata,
final Admin adminClient, final Admin adminClient,
final StateDirectory stateDirectory, final StateDirectory stateDirectory,
final StateUpdater stateUpdater) { final StateUpdater stateUpdater,
final DefaultTaskManager schedulingTaskManager
) {
this.time = time; this.time = time;
this.processId = processId; this.processId = processId;
this.logPrefix = logPrefix; this.logPrefix = logPrefix;
@ -124,6 +129,7 @@ public class TaskManager {
this.log = logContext.logger(getClass()); this.log = logContext.logger(getClass());
this.stateUpdater = stateUpdater; this.stateUpdater = stateUpdater;
this.schedulingTaskManager = schedulingTaskManager;
this.tasks = tasks; this.tasks = tasks;
this.taskExecutor = new TaskExecutor( this.taskExecutor = new TaskExecutor(
this.tasks, this.tasks,
@ -202,6 +208,11 @@ public class TaskManager {
* @throws TaskMigratedException * @throws TaskMigratedException
*/ */
boolean handleCorruption(final Set<TaskId> corruptedTasks) { boolean handleCorruption(final Set<TaskId> corruptedTasks) {
final Set<TaskId> activeTasks = new HashSet<>(tasks.activeTaskIds());
// We need to stop all processing, since we need to commit non-corrupted tasks as well.
maybeLockTasks(activeTasks);
final Set<Task> corruptedActiveTasks = new HashSet<>(); final Set<Task> corruptedActiveTasks = new HashSet<>();
final Set<Task> corruptedStandbyTasks = new HashSet<>(); final Set<Task> corruptedStandbyTasks = new HashSet<>();
@ -240,6 +251,9 @@ public class TaskManager {
} }
closeDirtyAndRevive(corruptedActiveTasks, true); closeDirtyAndRevive(corruptedActiveTasks, true);
maybeUnlockTasks(activeTasks);
return !corruptedActiveTasks.isEmpty(); return !corruptedActiveTasks.isEmpty();
} }
@ -329,6 +343,13 @@ public class TaskManager {
final Map<Task, Set<TopicPartition>> tasksToRecycle = new HashMap<>(); final Map<Task, Set<TopicPartition>> tasksToRecycle = new HashMap<>();
final Set<Task> tasksToCloseClean = new TreeSet<>(Comparator.comparing(Task::id)); final Set<Task> tasksToCloseClean = new TreeSet<>(Comparator.comparing(Task::id));
final Set<TaskId> tasksToLock =
tasks.allTaskIds().stream()
.filter(x -> activeTasksToCreate.containsKey(x) || standbyTasksToCreate.containsKey(x))
.collect(Collectors.toSet());
maybeLockTasks(tasksToLock);
// first put aside those unrecognized tasks because of unknown named-topologies // first put aside those unrecognized tasks because of unknown named-topologies
tasks.clearPendingTasksToCreate(); tasks.clearPendingTasksToCreate();
tasks.addPendingActiveTasksToCreate(pendingTasksToCreate(activeTasksToCreate)); tasks.addPendingActiveTasksToCreate(pendingTasksToCreate(activeTasksToCreate));
@ -346,6 +367,8 @@ public class TaskManager {
final Map<TaskId, RuntimeException> taskCloseExceptions = closeAndRecycleTasks(tasksToRecycle, tasksToCloseClean); final Map<TaskId, RuntimeException> taskCloseExceptions = closeAndRecycleTasks(tasksToRecycle, tasksToCloseClean);
maybeUnlockTasks(tasksToLock);
maybeThrowTaskExceptions(taskCloseExceptions); maybeThrowTaskExceptions(taskCloseExceptions);
createNewTasks(activeTasksToCreate, standbyTasksToCreate); createNewTasks(activeTasksToCreate, standbyTasksToCreate);
@ -964,6 +987,9 @@ public class TaskManager {
final Map<Task, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsPerTask = new HashMap<>(); final Map<Task, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsPerTask = new HashMap<>();
final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null); final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
final Set<TaskId> lockedTaskIds = activeRunningTaskIterable().stream().map(Task::id).collect(Collectors.toSet());
maybeLockTasks(lockedTaskIds);
for (final Task task : activeRunningTaskIterable()) { for (final Task task : activeRunningTaskIterable()) {
if (remainingRevokedPartitions.containsAll(task.inputPartitions())) { if (remainingRevokedPartitions.containsAll(task.inputPartitions())) {
// when the task input partitions are included in the revoked list, // when the task input partitions are included in the revoked list,
@ -1057,6 +1083,8 @@ public class TaskManager {
} }
} }
maybeUnlockTasks(lockedTaskIds);
if (firstException.get() != null) { if (firstException.get() != null) {
throw firstException.get(); throw firstException.get();
} }
@ -1112,6 +1140,8 @@ public class TaskManager {
private void closeRunningTasksDirty() { private void closeRunningTasksDirty() {
final Set<Task> allTask = tasks.allTasks(); final Set<Task> allTask = tasks.allTasks();
final Set<TaskId> allTaskIds = tasks.allTaskIds();
maybeLockTasks(allTaskIds);
for (final Task task : allTask) { for (final Task task : allTask) {
// Even though we've apparently dropped out of the group, we can continue safely to maintain our // Even though we've apparently dropped out of the group, we can continue safely to maintain our
// standby tasks while we rejoin. // standby tasks while we rejoin.
@ -1119,6 +1149,7 @@ public class TaskManager {
closeTaskDirty(task, true); closeTaskDirty(task, true);
} }
} }
maybeUnlockTasks(allTaskIds);
} }
private void removeLostActiveTasksFromStateUpdater() { private void removeLostActiveTasksFromStateUpdater() {
@ -1136,6 +1167,9 @@ public class TaskManager {
if (stateUpdater != null) { if (stateUpdater != null) {
stateUpdater.signalResume(); stateUpdater.signalResume();
} }
if (schedulingTaskManager != null) {
schedulingTaskManager.signalTaskExecutors();
}
} }
/** /**
@ -1307,6 +1341,7 @@ public class TaskManager {
void shutdown(final boolean clean) { void shutdown(final boolean clean) {
shutdownStateUpdater(); shutdownStateUpdater();
shutdownSchedulingTaskManager();
final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null); final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
@ -1356,6 +1391,12 @@ public class TaskManager {
} }
} }
private void shutdownSchedulingTaskManager() {
if (schedulingTaskManager != null) {
schedulingTaskManager.shutdown(Duration.ofMillis(Long.MAX_VALUE));
}
}
private void closeFailedTasksFromStateUpdater() { private void closeFailedTasksFromStateUpdater() {
final Set<Task> tasksToCloseDirty = stateUpdater.drainExceptionsAndFailedTasks().stream() final Set<Task> tasksToCloseDirty = stateUpdater.drainExceptionsAndFailedTasks().stream()
.flatMap(exAndTasks -> exAndTasks.getTasks().stream()).collect(Collectors.toSet()); .flatMap(exAndTasks -> exAndTasks.getTasks().stream()).collect(Collectors.toSet());
@ -1415,6 +1456,12 @@ public class TaskManager {
void closeAndCleanUpTasks(final Collection<Task> activeTasks, final Collection<Task> standbyTasks, final boolean clean) { void closeAndCleanUpTasks(final Collection<Task> activeTasks, final Collection<Task> standbyTasks, final boolean clean) {
final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null); final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
final Set<TaskId> ids =
activeTasks.stream()
.map(Task::id)
.collect(Collectors.toSet());
maybeLockTasks(ids);
final Set<Task> tasksToCloseDirty = new HashSet<>(); final Set<Task> tasksToCloseDirty = new HashSet<>();
tasksToCloseDirty.addAll(tryCloseCleanActiveTasks(activeTasks, clean, firstException)); tasksToCloseDirty.addAll(tryCloseCleanActiveTasks(activeTasks, clean, firstException));
tasksToCloseDirty.addAll(tryCloseCleanStandbyTasks(standbyTasks, clean, firstException)); tasksToCloseDirty.addAll(tryCloseCleanStandbyTasks(standbyTasks, clean, firstException));
@ -1423,6 +1470,8 @@ public class TaskManager {
closeTaskDirty(task, true); closeTaskDirty(task, true);
} }
maybeUnlockTasks(ids);
final RuntimeException exception = firstException.get(); final RuntimeException exception = firstException.get();
if (exception != null) { if (exception != null) {
throw exception; throw exception;
@ -1681,6 +1730,16 @@ public class TaskManager {
} }
} }
/**
* Wake-up any sleeping processing threads.
*/
public void signalTaskExecutors() {
if (schedulingTaskManager != null) {
// Wake up sleeping task executors after every poll, in case there is processing or punctuation to-do.
schedulingTaskManager.signalTaskExecutors();
}
}
/** /**
* Take records and add them to each respective task * Take records and add them to each respective task
* *
@ -1700,6 +1759,42 @@ public class TaskManager {
} }
} }
private void maybeLockTasks(final Set<TaskId> ids) {
if (schedulingTaskManager != null && !ids.isEmpty()) {
if (log.isDebugEnabled()) {
log.debug("Locking tasks {}", ids.stream().map(TaskId::toString).collect(Collectors.joining(", ")));
}
boolean locked = false;
while (!locked) {
try {
schedulingTaskManager.lockTasks(ids).get();
locked = true;
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for tasks {} to be locked",
ids.stream().map(TaskId::toString).collect(Collectors.joining(",")));
} catch (final ExecutionException e) {
log.info("Failed to lock tasks");
throw new RuntimeException(e);
}
}
}
}
private void maybeUnlockTasks(final Set<TaskId> ids) {
if (schedulingTaskManager != null && !ids.isEmpty()) {
if (log.isDebugEnabled()) {
log.debug("Unlocking tasks {}", ids.stream().map(TaskId::toString).collect(Collectors.joining(", ")));
}
schedulingTaskManager.unlockTasks(ids);
}
}
public void maybeThrowTaskExceptionsFromProcessingThreads() {
if (schedulingTaskManager != null) {
maybeThrowTaskExceptions(schedulingTaskManager.drainUncaughtExceptions());
}
}
/** /**
* @throws TaskMigratedException if committing offsets failed (non-EOS) * @throws TaskMigratedException if committing offsets failed (non-EOS)
* or if the task producer got fenced (EOS) * or if the task producer got fenced (EOS)
@ -1709,6 +1804,14 @@ public class TaskManager {
*/ */
int commit(final Collection<Task> tasksToCommit) { int commit(final Collection<Task> tasksToCommit) {
int committed = 0; int committed = 0;
final Set<TaskId> ids =
tasksToCommit.stream()
.map(Task::id)
.collect(Collectors.toSet());
maybeLockTasks(ids);
// We have to throw the first uncaught exception after locking the tasks, to not attempt to commit failure records.
maybeThrowTaskExceptionsFromProcessingThreads();
final Map<Task, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>(); final Map<Task, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
try { try {
@ -1719,6 +1822,7 @@ public class TaskManager {
.forEach(t -> t.maybeInitTaskTimeoutOrThrow(time.milliseconds(), timeoutException)); .forEach(t -> t.maybeInitTaskTimeoutOrThrow(time.milliseconds(), timeoutException));
} }
maybeUnlockTasks(ids);
return committed; return committed;
} }

5
streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java

@ -349,6 +349,11 @@ class Tasks implements TasksRegistry {
return tasks; return tasks;
} }
@Override
public synchronized Collection<TaskId> activeTaskIds() {
return Collections.unmodifiableCollection(activeTasksPerId.keySet());
}
@Override @Override
public synchronized Collection<Task> activeTasks() { public synchronized Collection<Task> activeTasks() {
return Collections.unmodifiableCollection(activeTasksPerId.values()); return Collections.unmodifiableCollection(activeTasksPerId.values());

2
streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java

@ -85,6 +85,8 @@ public interface TasksRegistry {
Collection<Task> tasks(final Collection<TaskId> taskIds); Collection<Task> tasks(final Collection<TaskId> taskIds);
Collection<TaskId> activeTaskIds();
Collection<Task> activeTasks(); Collection<Task> activeTasks();
Set<Task> allTasks(); Set<Task> allTasks();

11
streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskExecutor.java

@ -194,7 +194,16 @@ public class DefaultTaskExecutor implements TaskExecutor {
// flush the task before giving it back to task manager, if we are not handing it back because of an error. // flush the task before giving it back to task manager, if we are not handing it back because of an error.
if (!taskManager.hasUncaughtException(currentTask.id())) { if (!taskManager.hasUncaughtException(currentTask.id())) {
currentTask.flush(); try {
currentTask.flush();
} catch (final StreamsException e) {
log.error(String.format("Failed to flush stream task %s due to the following error:", currentTask.id()), e);
e.setTaskId(currentTask.id());
taskManager.setUncaughtException(e, currentTask.id());
} catch (final RuntimeException e) {
log.error(String.format("Failed to flush stream task %s due to the following error:", currentTask.id()), e);
taskManager.setUncaughtException(new StreamsException(e, currentTask.id()), currentTask.id());
}
} }
taskManager.unassignTask(currentTask, DefaultTaskExecutor.this); taskManager.unassignTask(currentTask, DefaultTaskExecutor.this);

6
streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java

@ -319,9 +319,9 @@ public class DefaultTaskManager implements TaskManager {
exception.getMessage()); exception.getMessage());
} }
public Map<TaskId, StreamsException> drainUncaughtExceptions() { public Map<TaskId, RuntimeException> drainUncaughtExceptions() {
final Map<TaskId, StreamsException> returnValue = returnWithTasksLocked(() -> { final Map<TaskId, RuntimeException> returnValue = returnWithTasksLocked(() -> {
final Map<TaskId, StreamsException> result = new HashMap<>(uncaughtExceptions); final Map<TaskId, RuntimeException> result = new HashMap<>(uncaughtExceptions);
uncaughtExceptions.clear(); uncaughtExceptions.clear();
return result; return result;
}); });

2
streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/TaskManager.java

@ -118,7 +118,7 @@ public interface TaskManager {
* *
* @return A map from task ID to the exception that occurred. * @return A map from task ID to the exception that occurred.
*/ */
Map<TaskId, StreamsException> drainUncaughtExceptions(); Map<TaskId, RuntimeException> drainUncaughtExceptions();
/** /**
* Can be used to check if a specific task has an uncaught exception. * Can be used to check if a specific task has an uncaught exception.

2
streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java

@ -442,7 +442,7 @@ public class AdjustStreamThreadCountTest {
@Override @Override
public void init(final ProcessorContext context) { public void init(final ProcessorContext context) {
context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, timestamp -> { context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, timestamp -> {
if (Thread.currentThread().getName().endsWith("StreamThread-1") && injectError.get()) { if (Thread.currentThread().getName().contains("StreamThread-1") && injectError.get()) {
injectError.set(false); injectError.set(false);
throw new RuntimeException("BOOM"); throw new RuntimeException("BOOM");
} }

31
streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java

@ -34,6 +34,7 @@ import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder; import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.StreamsConfig; import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.StreamsConfig.InternalConfig;
import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
import org.apache.kafka.streams.kstream.KStream; import org.apache.kafka.streams.kstream.KStream;
@ -149,18 +150,24 @@ public class EosIntegrationTest {
private String stateTmpDir; private String stateTmpDir;
@SuppressWarnings("deprecation") @SuppressWarnings("deprecation")
@Parameters(name = "{0}") @Parameters(name = "{0}, processing threads = {1}")
public static Collection<String[]> data() { public static Collection<Object[]> data() {
return Arrays.asList(new String[][]{ return Arrays.asList(new Object[][]{
{StreamsConfig.AT_LEAST_ONCE}, {StreamsConfig.AT_LEAST_ONCE, false},
{StreamsConfig.EXACTLY_ONCE}, {StreamsConfig.EXACTLY_ONCE, false},
{StreamsConfig.EXACTLY_ONCE_V2} {StreamsConfig.EXACTLY_ONCE_V2, false},
{StreamsConfig.AT_LEAST_ONCE, true},
{StreamsConfig.EXACTLY_ONCE, true},
{StreamsConfig.EXACTLY_ONCE_V2, true}
}); });
} }
@Parameter @Parameter(0)
public String eosConfig; public String eosConfig;
@Parameter(1)
public boolean processingThreadsEnabled;
@Before @Before
public void createTopics() throws Exception { public void createTopics() throws Exception {
applicationId = "appId-" + TEST_NUMBER.getAndIncrement(); applicationId = "appId-" + TEST_NUMBER.getAndIncrement();
@ -876,10 +883,15 @@ public class EosIntegrationTest {
LOG.info(dummyHostName + " is executing the injected stall"); LOG.info(dummyHostName + " is executing the injected stall");
stallingHost.set(dummyHostName); stallingHost.set(dummyHostName);
while (doStall) { while (doStall) {
final StreamThread thread = (StreamThread) Thread.currentThread(); final Thread thread = Thread.currentThread();
if (thread.isInterrupted() || !thread.isRunning()) { if (thread.isInterrupted()) {
throw new RuntimeException("Detected we've been interrupted."); throw new RuntimeException("Detected we've been interrupted.");
} }
if (!processingThreadsEnabled) {
if (!((StreamThread) thread).isRunning()) {
throw new RuntimeException("Detected we've been interrupted.");
}
}
try { try {
Thread.sleep(100); Thread.sleep(100);
} catch (final InterruptedException e) { } catch (final InterruptedException e) {
@ -943,6 +955,7 @@ public class EosIntegrationTest {
properties.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0); properties.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
properties.put(StreamsConfig.STATE_DIR_CONFIG, stateTmpDir + appDir); properties.put(StreamsConfig.STATE_DIR_CONFIG, stateTmpDir + appDir);
properties.put(StreamsConfig.APPLICATION_SERVER_CONFIG, dummyHostName + ":2142"); properties.put(StreamsConfig.APPLICATION_SERVER_CONFIG, dummyHostName + ":2142");
properties.put(InternalConfig.PROCESSING_THREADS_ENABLED, processingThreadsEnabled);
final Properties config = StreamsTestUtils.getStreamsConfig( final Properties config = StreamsTestUtils.getStreamsConfig(
applicationId, applicationId,

13
streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java

@ -30,6 +30,7 @@ import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import java.io.IOException; import java.io.IOException;
@ -94,11 +95,12 @@ public class SmokeTestDriverIntegrationTest {
} }
private static Stream<Boolean> parameters() { private static Stream<Arguments> parameters() {
return Stream.of( return Stream.of(
Boolean.TRUE, Arguments.of(false, false),
Boolean.FALSE Arguments.of(true, false),
); Arguments.of(true, true)
);
} }
// In this test, we try to keep creating new stream, and closing the old one, to maintain only 3 streams alive. // In this test, we try to keep creating new stream, and closing the old one, to maintain only 3 streams alive.
@ -107,7 +109,7 @@ public class SmokeTestDriverIntegrationTest {
// (1) 10 min timeout, (2) 30 tries of polling without getting any data // (1) 10 min timeout, (2) 30 tries of polling without getting any data
@ParameterizedTest @ParameterizedTest
@MethodSource("parameters") @MethodSource("parameters")
public void shouldWorkWithRebalance(final boolean stateUpdaterEnabled) throws InterruptedException { public void shouldWorkWithRebalance(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) throws InterruptedException {
Exit.setExitProcedure((statusCode, message) -> { Exit.setExitProcedure((statusCode, message) -> {
throw new AssertionError("Test called exit(). code:" + statusCode + " message:" + message); throw new AssertionError("Test called exit(). code:" + statusCode + " message:" + message);
}); });
@ -128,6 +130,7 @@ public class SmokeTestDriverIntegrationTest {
final Properties props = new Properties(); final Properties props = new Properties();
props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
props.put(InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled); props.put(InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled);
props.put(InternalConfig.PROCESSING_THREADS_ENABLED, processingThreadsEnabled);
// decrease the session timeout so that we can trigger the rebalance soon after old client left closed // decrease the session timeout so that we can trigger the rebalance soon after old client left closed
props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000); props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000);

1
streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java

@ -497,6 +497,7 @@ public class ActiveTaskCreatorTest {
"clientId-StreamThread-0", "clientId-StreamThread-0",
uuid, uuid,
new LogContext().logger(ActiveTaskCreator.class), new LogContext().logger(ActiveTaskCreator.class),
false,
false); false);
assertThat( assertThat(

26
streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java

@ -33,6 +33,7 @@ import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler; import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
import org.apache.kafka.streams.processor.TimestampExtractor; import org.apache.kafka.streams.processor.TimestampExtractor;
import org.apache.kafka.common.utils.LogCaptureAppender; import org.apache.kafka.common.utils.LogCaptureAppender;
import org.apache.kafka.streams.processor.internals.AbstractPartitionGroup.RecordInfo;
import org.apache.kafka.test.InternalMockProcessorContext; import org.apache.kafka.test.InternalMockProcessorContext;
import org.apache.kafka.test.MockSourceNode; import org.apache.kafka.test.MockSourceNode;
import org.apache.kafka.test.MockTimestampExtractor; import org.apache.kafka.test.MockTimestampExtractor;
@ -133,7 +134,7 @@ public class PartitionGroupTest {
private void testFirstBatch(final PartitionGroup group) { private void testFirstBatch(final PartitionGroup group) {
StampedRecord record; StampedRecord record;
final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo(); final PartitionGroup.RecordInfo info = new RecordInfo();
assertThat(group.numBuffered(), is(0)); assertThat(group.numBuffered(), is(0));
// add three 3 records with timestamp 1, 3, 5 to partition-1 // add three 3 records with timestamp 1, 3, 5 to partition-1
@ -193,7 +194,7 @@ public class PartitionGroupTest {
private void testSecondBatch(final PartitionGroup group) { private void testSecondBatch(final PartitionGroup group) {
StampedRecord record; StampedRecord record;
final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo(); final PartitionGroup.RecordInfo info = new RecordInfo();
// add 2 more records with timestamp 2, 4 to partition-1 // add 2 more records with timestamp 2, 4 to partition-1
final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList( final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList(
@ -316,7 +317,7 @@ public class PartitionGroupTest {
assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue()); assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
StampedRecord record; StampedRecord record;
final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo(); final PartitionGroup.RecordInfo info = new RecordInfo();
// get first two records from partition 1 // get first two records from partition 1
record = group.nextRecord(info, time.milliseconds()); record = group.nextRecord(info, time.milliseconds());
@ -445,15 +446,16 @@ public class PartitionGroupTest {
new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue), new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue),
new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue));
group.addRawRecords(partition1, list); group.addRawRecords(partition1, list);
group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); group.nextRecord(new RecordInfo(), time.milliseconds());
group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); group.nextRecord(new RecordInfo(), time.milliseconds());
group.updateLags(); group.updateLags();
group.clear(); group.clear();
assertThat(group.numBuffered(), equalTo(0)); assertThat(group.numBuffered(), equalTo(0));
assertThat(group.streamTime(), equalTo(RecordQueue.UNKNOWN)); assertThat(group.streamTime(), equalTo(RecordQueue.UNKNOWN));
assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), equalTo(null)); assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), equalTo(null));
assertThat(group.partitionTimestamp(partition1), equalTo(RecordQueue.UNKNOWN)); assertThat(group.partitionTimestamp(partition1), equalTo(RecordQueue.UNKNOWN));
hasNoFetchedLag(group, partition1); hasNoFetchedLag(group, partition1);
@ -475,7 +477,7 @@ public class PartitionGroupTest {
group.addRawRecords(partition2, list2); group.addRawRecords(partition2, list2);
assertEquals(list1.size() + list2.size(), group.numBuffered()); assertEquals(list1.size() + list2.size(), group.numBuffered());
assertTrue(group.allPartitionsBufferedLocally()); assertTrue(group.allPartitionsBufferedLocally());
group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); group.nextRecord(new RecordInfo(), time.milliseconds());
// shrink list of queues // shrink list of queues
group.updatePartitions(mkSet(createPartition2()), p -> { group.updatePartitions(mkSet(createPartition2()), p -> {
@ -487,7 +489,7 @@ public class PartitionGroupTest {
assertEquals(list2.size(), group.numBuffered()); assertEquals(list2.size(), group.numBuffered());
assertEquals(1, group.streamTime()); assertEquals(1, group.streamTime());
assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1)); assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1));
assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records
assertThat(group.partitionTimestamp(partition2), equalTo(2L)); assertThat(group.partitionTimestamp(partition2), equalTo(2L));
} }
@ -508,7 +510,7 @@ public class PartitionGroupTest {
assertEquals(list1.size(), group.numBuffered()); assertEquals(list1.size(), group.numBuffered());
assertTrue(group.allPartitionsBufferedLocally()); assertTrue(group.allPartitionsBufferedLocally());
group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); group.nextRecord(new RecordInfo(), time.milliseconds());
// expand list of queues // expand list of queues
group.updatePartitions(mkSet(createPartition1(), createPartition2()), p -> { group.updatePartitions(mkSet(createPartition1(), createPartition2()), p -> {
@ -521,7 +523,7 @@ public class PartitionGroupTest {
assertEquals(1, group.streamTime()); assertEquals(1, group.streamTime());
assertThat(group.partitionTimestamp(partition1), equalTo(1L)); assertThat(group.partitionTimestamp(partition1), equalTo(1L));
assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN)); assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN));
assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records
} }
@Test @Test
@ -540,7 +542,7 @@ public class PartitionGroupTest {
group.addRawRecords(partition1, list1); group.addRawRecords(partition1, list1);
assertEquals(list1.size(), group.numBuffered()); assertEquals(list1.size(), group.numBuffered());
assertTrue(group.allPartitionsBufferedLocally()); assertTrue(group.allPartitionsBufferedLocally());
group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); group.nextRecord(new RecordInfo(), time.milliseconds());
// expand and shrink list of queues // expand and shrink list of queues
group.updatePartitions(mkSet(createPartition2()), p -> { group.updatePartitions(mkSet(createPartition2()), p -> {
@ -553,7 +555,7 @@ public class PartitionGroupTest {
assertEquals(1, group.streamTime()); assertEquals(1, group.streamTime());
assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1)); assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1));
assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN)); assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN));
assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), nullValue()); // all available records removed assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), nullValue()); // all available records removed
} }
@Test @Test

32
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java

@ -1841,7 +1841,9 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext); logContext,
false
);
task.initializeIfNeeded(); task.initializeIfNeeded();
task.completeRestoration(noOpResetter -> { }); task.completeRestoration(noOpResetter -> { });
@ -2504,7 +2506,9 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext) logContext,
false
)
); );
assertThat(exception.getMessage(), equalTo("Invalid topology: " + assertThat(exception.getMessage(), equalTo("Invalid topology: " +
@ -2700,7 +2704,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }
@ -2741,7 +2746,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }
@ -2774,7 +2780,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }
@ -2812,7 +2819,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }
@ -2852,7 +2860,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }
@ -2893,7 +2902,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }
@ -2932,7 +2942,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }
@ -2966,7 +2977,8 @@ public class StreamTaskTest {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext logContext,
false
); );
} }

838
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java

File diff suppressed because it is too large Load Diff

201
streams/src/test/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroupTest.java

@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.internals.AbstractPartitionGroup.RecordInfo;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.util.Collections;
import java.util.Set;
import java.util.function.Function;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
public class SynchronizedPartitionGroupTest {
@Mock
private AbstractPartitionGroup wrapped;
private SynchronizedPartitionGroup synchronizedPartitionGroup;
private AutoCloseable closeable;
@BeforeEach
public void setUp() {
closeable = MockitoAnnotations.openMocks(this);
synchronizedPartitionGroup = new SynchronizedPartitionGroup(wrapped);
}
@AfterEach
public void tearDown() throws Exception {
closeable.close();
}
@Test
public void testReadyToProcess() {
final long wallClockTime = 0L;
when(wrapped.readyToProcess(wallClockTime)).thenReturn(true);
synchronizedPartitionGroup.readyToProcess(wallClockTime);
verify(wrapped, times(1)).readyToProcess(wallClockTime);
}
@Test
public void testUpdatePartitions() {
final Set<TopicPartition> inputPartitions = Collections.singleton(new TopicPartition("topic", 0));
@SuppressWarnings("unchecked") final Function<TopicPartition, RecordQueue> recordQueueCreator = (Function<TopicPartition, RecordQueue>) mock(Function.class);
synchronizedPartitionGroup.updatePartitions(inputPartitions, recordQueueCreator);
verify(wrapped, times(1)).updatePartitions(inputPartitions, recordQueueCreator);
}
@Test
public void testSetPartitionTime() {
final TopicPartition partition = new TopicPartition("topic", 0);
final long partitionTime = 12345678L;
synchronizedPartitionGroup.setPartitionTime(partition, partitionTime);
verify(wrapped, times(1)).setPartitionTime(partition, partitionTime);
}
@Test
public void testNextRecord() {
final RecordInfo info = mock(RecordInfo.class);
final long wallClockTime = 12345678L;
final StampedRecord stampedRecord = mock(StampedRecord.class);
when(wrapped.nextRecord(info, wallClockTime)).thenReturn(stampedRecord);
final StampedRecord result = synchronizedPartitionGroup.nextRecord(info, wallClockTime);
assertEquals(stampedRecord, result);
verify(wrapped, times(1)).nextRecord(info, wallClockTime);
}
@Test
public void testAddRawRecords() {
final TopicPartition partition = new TopicPartition("topic", 0);
@SuppressWarnings("unchecked") final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords = (Iterable<ConsumerRecord<byte[], byte[]>>) mock(Iterable.class);
when(wrapped.addRawRecords(partition, rawRecords)).thenReturn(1);
final int result = synchronizedPartitionGroup.addRawRecords(partition, rawRecords);
assertEquals(1, result);
verify(wrapped, times(1)).addRawRecords(partition, rawRecords);
}
@Test
public void testPartitionTimestamp() {
final TopicPartition partition = new TopicPartition("topic", 0);
final long timestamp = 12345678L;
when(wrapped.partitionTimestamp(partition)).thenReturn(timestamp);
final long result = synchronizedPartitionGroup.partitionTimestamp(partition);
assertEquals(timestamp, result);
verify(wrapped, times(1)).partitionTimestamp(partition);
}
@Test
public void testStreamTime() {
final long streamTime = 12345678L;
when(wrapped.streamTime()).thenReturn(streamTime);
final long result = synchronizedPartitionGroup.streamTime();
assertEquals(streamTime, result);
verify(wrapped, times(1)).streamTime();
}
@Test
public void testHeadRecordOffset() {
final TopicPartition partition = new TopicPartition("topic", 0);
final Long recordOffset = 0L;
when(wrapped.headRecordOffset(partition)).thenReturn(recordOffset);
final Long result = synchronizedPartitionGroup.headRecordOffset(partition);
assertEquals(recordOffset, result);
verify(wrapped, times(1)).headRecordOffset(partition);
}
@Test
public void testNumBuffered() {
final int numBuffered = 1;
when(wrapped.numBuffered()).thenReturn(numBuffered);
final int result = synchronizedPartitionGroup.numBuffered();
assertEquals(numBuffered, result);
verify(wrapped, times(1)).numBuffered();
}
@Test
public void testNumBufferedWithTopicPartition() {
final TopicPartition partition = new TopicPartition("topic", 0);
final int numBuffered = 1;
when(wrapped.numBuffered(partition)).thenReturn(numBuffered);
final int result = synchronizedPartitionGroup.numBuffered(partition);
assertEquals(numBuffered, result);
verify(wrapped, times(1)).numBuffered(partition);
}
@Test
public void testClear() {
synchronizedPartitionGroup.clear();
verify(wrapped, times(1)).clear();
}
@Test
public void testUpdateLags() {
synchronizedPartitionGroup.updateLags();
verify(wrapped, times(1)).updateLags();
}
@Test
public void testClose() {
synchronizedPartitionGroup.close();
verify(wrapped, times(1)).close();
}
@Test
public void testPartitions() {
final Set<TopicPartition> partitions = Collections.singleton(new TopicPartition("topic", 0));
when(wrapped.partitions()).thenReturn(partitions);
final Set<TopicPartition> result = synchronizedPartitionGroup.partitions();
assertEquals(partitions, result);
verify(wrapped, times(1)).partitions();
}
}

124
streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java

@ -26,6 +26,7 @@ import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.KafkaFuture;
import org.apache.kafka.common.Metric; import org.apache.kafka.common.Metric;
import org.apache.kafka.common.MetricName; import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.TopicPartition;
@ -48,6 +49,7 @@ import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory; import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks; import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks;
import org.apache.kafka.streams.processor.internals.Task.State; import org.apache.kafka.streams.processor.internals.Task.State;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig; import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig;
import org.apache.kafka.common.utils.LogCaptureAppender; import org.apache.kafka.common.utils.LogCaptureAppender;
import org.apache.kafka.streams.state.internals.OffsetCheckpoint; import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
@ -199,6 +201,7 @@ public class TaskManagerTest {
@org.mockito.Mock @org.mockito.Mock
private Admin adminClient; private Admin adminClient;
final StateUpdater stateUpdater = Mockito.mock(StateUpdater.class); final StateUpdater stateUpdater = Mockito.mock(StateUpdater.class);
final DefaultTaskManager schedulingTaskManager = Mockito.mock(DefaultTaskManager.class);
private TaskManager taskManager; private TaskManager taskManager;
private TopologyMetadata topologyMetadata; private TopologyMetadata topologyMetadata;
@ -212,16 +215,21 @@ public class TaskManagerTest {
@Before @Before
public void setUp() { public void setUp() {
taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, false); taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, null, false);
} }
private TaskManager setUpTaskManager(final ProcessingMode processingMode, final boolean stateUpdaterEnabled) { private TaskManager setUpTaskManager(final ProcessingMode processingMode, final boolean stateUpdaterEnabled) {
return setUpTaskManager(processingMode, null, stateUpdaterEnabled); return setUpTaskManager(processingMode, null, stateUpdaterEnabled, false);
}
private TaskManager setUpTaskManager(final ProcessingMode processingMode, final TasksRegistry tasks, final boolean stateUpdaterEnabled) {
return setUpTaskManager(processingMode, tasks, stateUpdaterEnabled, false);
} }
private TaskManager setUpTaskManager(final ProcessingMode processingMode, private TaskManager setUpTaskManager(final ProcessingMode processingMode,
final TasksRegistry tasks, final TasksRegistry tasks,
final boolean stateUpdaterEnabled) { final boolean stateUpdaterEnabled,
final boolean processingThreadsEnabled) {
topologyMetadata = new TopologyMetadata(topologyBuilder, new DummyStreamsConfig(processingMode)); topologyMetadata = new TopologyMetadata(topologyBuilder, new DummyStreamsConfig(processingMode));
final TaskManager taskManager = new TaskManager( final TaskManager taskManager = new TaskManager(
time, time,
@ -234,7 +242,8 @@ public class TaskManagerTest {
topologyMetadata, topologyMetadata,
adminClient, adminClient,
stateDirectory, stateDirectory,
stateUpdaterEnabled ? stateUpdater : null stateUpdaterEnabled ? stateUpdater : null,
processingThreadsEnabled ? schedulingTaskManager : null
); );
taskManager.setMainConsumer(consumer); taskManager.setMainConsumer(consumer);
return taskManager; return taskManager;
@ -287,6 +296,103 @@ public class TaskManagerTest {
Mockito.verify(standbyTask).resume(); Mockito.verify(standbyTask).resume();
} }
@Test
public void shouldLockAllTasksOnCorruptionWithProcessingThreads() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.activeTaskIds()).thenReturn(mkSet(taskId00, taskId01));
when(tasks.task(taskId00)).thenReturn(activeTask1);
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
replay(consumer);
taskManager.handleCorruption(mkSet(taskId00));
verify(consumer);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockCommitableTasksOnCorruptionWithProcessingThreads() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final StreamTask activeTask2 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId01Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.commit(mkSet(activeTask1, activeTask2));
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockActiveOnHandleAssignmentWithProcessingThreads() {
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.allTaskIds()).thenReturn(mkSet(taskId00, taskId01));
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.handleAssignment(
mkMap(mkEntry(taskId00, taskId00Partitions)),
mkMap(mkEntry(taskId01, taskId01Partitions))
);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockAffectedTasksOnHandleRevocation() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final StreamTask activeTask2 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId01Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.allTasks()).thenReturn(mkSet(activeTask1, activeTask2));
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.handleRevocation(taskId01Partitions);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockTasksOnClose() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final StreamTask activeTask2 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId01Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.allTasks()).thenReturn(mkSet(activeTask1, activeTask2));
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.closeAndCleanUpTasks(mkSet(activeTask1), mkSet(), false);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00));
}
@Test @Test
public void shouldResumePollingForPartitionsWithAvailableSpaceForAllActiveTasks() { public void shouldResumePollingForPartitionsWithAvailableSpaceForAllActiveTasks() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions) final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
@ -3553,6 +3659,16 @@ public class TaskManagerTest {
Mockito.verify(failedStatefulTask).closeDirty(); Mockito.verify(failedStatefulTask).closeDirty();
} }
@Test
public void shouldShutdownSchedulingTaskManager() {
final TasksRegistry tasks = mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
taskManager.shutdown(true);
Mockito.verify(schedulingTaskManager).shutdown(Duration.ofMillis(Long.MAX_VALUE));
}
@Test @Test
public void shouldShutDownStateUpdaterAndAddRestoredTasksToTaskRegistry() { public void shouldShutDownStateUpdaterAndAddRestoredTasksToTaskRegistry() {
final TasksRegistry tasks = mock(TasksRegistry.class); final TasksRegistry tasks = mock(TasksRegistry.class);

5
streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java

@ -460,7 +460,10 @@ public class StreamThreadStateStoreProviderTest {
new MockTime(), new MockTime(),
stateManager, stateManager,
recordCollector, recordCollector,
context, logContext); context,
logContext,
false
);
} }
private void mockThread(final boolean initialized) { private void mockThread(final boolean initialized) {

4
streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java

@ -522,7 +522,9 @@ public class TopologyTestDriver implements Closeable {
stateManager, stateManager,
recordCollector, recordCollector,
context, context,
logContext); logContext,
false
);
task.initializeIfNeeded(); task.initializeIfNeeded();
task.completeRestoration(noOpResetter -> { }); task.completeRestoration(noOpResetter -> { });
task.processorContext().setRecordContext(null); task.processorContext().setRecordContext(null);

Loading…
Cancel
Save