
KAFKA-15326: [10/N] Integrate processing thread (#14193)

- Introduce a new internal config flag to enable processing threads (a usage sketch follows the StreamsConfig diff below)
- If enabled, create a scheduling task manager inside the normal task manager (renamings will be added on top of this), and use it from the stream thread
- All operations inside the task manager that change task state now lock the corresponding tasks if processing threads are enabled
- Add a new abstract class AbstractPartitionGroup. We can modify the underlying implementation depending on the synchronization requirements. PartitionGroup is the unsynchronized subclass that is used by the original code path. The processing-thread code path uses a trivially synchronized SynchronizedPartitionGroup based on object monitors. Further down the road, there is the opportunity to implement a weakly synchronized alternative. The details are complex, but since the implementation is essentially a queue plus some bookkeeping, it should be feasible to make it lock-free (an illustrative sketch follows the AbstractPartitionGroup diff below)
- Refactorings in StreamThreadTest: make all tests use the thread member variable and add tearDown in order to avoid thread leaks and to simplify debugging. Make the tests parameterized on two internal flags: state updater enabled and processing threads enabled. Use JUnit's assume to disable all tests that do not apply
- Enable some integration tests with processing threads enabled

Reviewer: Bruno Cadonna <bruno@confluent.io>
Author: Lucas Brutschy (committed via GitHub)
Commit: d144b7ee38
Changed files (lines changed):
  1. clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java (27)
  2. clients/src/test/java/org/apache/kafka/clients/producer/MockProducerTest.java (11)
  3. streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java (7)
  4. streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java (85)
  5. streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java (12)
  6. streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java (49)
  7. streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java (42)
  8. streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java (131)
  9. streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java (101)
  10. streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java (106)
  11. streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java (5)
  12. streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java (2)
  13. streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskExecutor.java (11)
  14. streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java (6)
  15. streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/TaskManager.java (2)
  16. streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java (2)
  17. streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java (31)
  18. streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java (13)
  19. streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java (1)
  20. streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java (26)
  21. streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java (32)
  22. streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java (838)
  23. streams/src/test/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroupTest.java (201)
  24. streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java (124)
  25. streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java (5)
  26. streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java (4)

clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java (27 lines changed)

@@ -160,7 +160,8 @@
@Override
public void initTransactions() {
verifyProducerState();
verifyNotClosed();
verifyNotFenced();
if (this.transactionInitialized) {
throw new IllegalStateException("MockProducer has already been initialized for transactions.");
}
@@ -176,7 +177,8 @@
@Override
public void beginTransaction() throws ProducerFencedException {
verifyProducerState();
verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized();
if (this.beginTransactionException != null) {
@@ -205,7 +207,8 @@
public void sendOffsetsToTransaction(Map<TopicPartition, OffsetAndMetadata> offsets,
ConsumerGroupMetadata groupMetadata) throws ProducerFencedException {
Objects.requireNonNull(groupMetadata);
verifyProducerState();
verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized();
verifyTransactionInFlight();
@@ -224,7 +227,8 @@
@Override
public void commitTransaction() throws ProducerFencedException {
verifyProducerState();
verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized();
verifyTransactionInFlight();
@@ -249,7 +253,8 @@
@Override
public void abortTransaction() throws ProducerFencedException {
verifyProducerState();
verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized();
verifyTransactionInFlight();
@@ -265,10 +270,13 @@
this.transactionInFlight = false;
}
private synchronized void verifyProducerState() {
private synchronized void verifyNotClosed() {
if (this.closed) {
throw new IllegalStateException("MockProducer is already closed.");
}
}
private synchronized void verifyNotFenced() {
if (this.producerFenced) {
throw new ProducerFencedException("MockProducer is fenced.");
}
@@ -288,7 +296,7 @@
/**
* Adds the record to the list of sent records. The {@link RecordMetadata} returned will be immediately satisfied.
*
* @see #history()
*/
@Override
@@ -362,7 +370,7 @@
}
public synchronized void flush() {
verifyProducerState();
verifyNotClosed();
if (this.flushException != null) {
throw this.flushException;
@@ -415,7 +423,8 @@
}
public synchronized void fenceProducer() {
verifyProducerState();
verifyNotClosed();
verifyNotFenced();
verifyTransactionsInitialized();
this.producerFenced = true;
}

clients/src/test/java/org/apache/kafka/clients/producer/MockProducerTest.java (11 lines changed)

@@ -40,6 +40,7 @@
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -702,7 +703,15 @@
producer.close();
assertThrows(IllegalStateException.class, producer::flush);
}
@Test
public void shouldNotThrowOnFlushProducerIfProducerIsFenced() {
buildMockProducer(true);
producer.initTransactions();
producer.fenceProducer();
assertDoesNotThrow(producer::flush);
}
@Test
@SuppressWarnings("unchecked")
public void shouldThrowClassCastException() {

streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java (7 lines changed)

@@ -1200,6 +1200,13 @@
public static boolean getStateUpdaterEnabled(final Map<String, Object> configs) {
return InternalConfig.getBoolean(configs, InternalConfig.STATE_UPDATER_ENABLED, true);
}
// Private API to enable processing threads (i.e. polling is decoupled from processing)
public static final String PROCESSING_THREADS_ENABLED = "__processing.threads.enabled__";
public static boolean getProcessingThreadsEnabled(final Map<String, Object> configs) {
return InternalConfig.getBoolean(configs, InternalConfig.PROCESSING_THREADS_ENABLED, false);
}
public static boolean getBoolean(final Map<String, Object> configs, final String key, final boolean defaultValue) {
final Object value = configs.getOrDefault(key, defaultValue);
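For illustration only: the new private flag is toggled the same way as the existing state-updater flag, by putting the InternalConfig key into the application's properties, as the integration tests below do. A minimal sketch, assuming placeholder application id and bootstrap servers (the example class itself is hypothetical and not part of this commit):

import java.util.Properties;
import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.StreamsConfig.InternalConfig;

public class ProcessingThreadsConfigSketch {
    public static Properties processingThreadsProps() {
        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "example-app");       // placeholder
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder
        // Processing threads require the state updater; StreamThread throws otherwise.
        // The state updater already defaults to true, set here only for emphasis.
        props.put(InternalConfig.STATE_UPDATER_ENABLED, true);
        // New private flag introduced by this commit; defaults to false.
        props.put(InternalConfig.PROCESSING_THREADS_ENABLED, true);
        return props;
    }
}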

streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractPartitionGroup.java (85 lines changed)

@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals;
import java.util.Set;
import java.util.function.Function;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
abstract class AbstractPartitionGroup {
abstract boolean readyToProcess(long wallClockTime);
// creates queues for new partitions, removes old queues, saves cached records for previously assigned partitions
abstract void updatePartitions(Set<TopicPartition> inputPartitions, Function<TopicPartition, RecordQueue> recordQueueCreator);
abstract void setPartitionTime(TopicPartition partition, long partitionTime);
/**
* Get the next record and queue
*
* @return StampedRecord
*/
abstract StampedRecord nextRecord(RecordInfo info, long wallClockTime);
/**
* Adds raw records to this partition group
*
* @param partition the partition
* @param rawRecords the raw records
* @return the queue size for the partition
*/
abstract int addRawRecords(TopicPartition partition, Iterable<ConsumerRecord<byte[], byte[]>> rawRecords);
abstract long partitionTimestamp(final TopicPartition partition);
/**
* Return the stream-time of this partition group defined as the largest timestamp seen across all partitions
*/
abstract long streamTime();
abstract Long headRecordOffset(final TopicPartition partition);
abstract int numBuffered();
abstract int numBuffered(TopicPartition tp);
abstract void clear();
abstract void updateLags();
abstract void close();
abstract Set<TopicPartition> partitions();
static class RecordInfo {
RecordQueue queue;
ProcessorNode<?, ?, ?, ?> node() {
return queue.source();
}
TopicPartition partition() {
return queue.partition();
}
RecordQueue queue() {
return queue;
}
}
}
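The weakly synchronized, lock-free alternative mentioned in the commit message is not implemented here. Purely as an illustration of the direction, the sketch below shows how the buffering part of such an implementation might look, using a concurrent queue per partition and atomic bookkeeping. The class name and every simplification are hypothetical: record choosing across partitions, per-partition stream-time, idling, and lag tracking are all omitted.

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;

// Hypothetical, illustrative only: not part of this commit.
class LockFreePartitionGroupSketch {

    private final Map<TopicPartition, ConcurrentLinkedQueue<ConsumerRecord<byte[], byte[]>>> queues =
        new ConcurrentHashMap<>();
    private final AtomicInteger totalBuffered = new AtomicInteger();
    private final AtomicLong streamTime = new AtomicLong(Long.MIN_VALUE);

    // Called from the polling thread: enqueue without blocking the processing threads.
    int addRawRecords(final TopicPartition partition,
                      final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
        final ConcurrentLinkedQueue<ConsumerRecord<byte[], byte[]>> queue =
            queues.computeIfAbsent(partition, p -> new ConcurrentLinkedQueue<>());
        int added = 0;
        for (final ConsumerRecord<byte[], byte[]> record : rawRecords) {
            queue.add(record);
            added++;
        }
        totalBuffered.addAndGet(added);
        return queue.size(); // note: O(n) on ConcurrentLinkedQueue, shown only for API parity
    }

    // Called from a processing thread: dequeue and advance stream-time monotonically via CAS.
    ConsumerRecord<byte[], byte[]> nextRecord(final TopicPartition partition) {
        final ConcurrentLinkedQueue<ConsumerRecord<byte[], byte[]>> queue = queues.get(partition);
        final ConsumerRecord<byte[], byte[]> record = queue == null ? null : queue.poll();
        if (record != null) {
            totalBuffered.decrementAndGet();
            streamTime.accumulateAndGet(record.timestamp(), Math::max);
        }
        return record;
    }

    int numBuffered() {
        return totalBuffered.get();
    }

    Set<TopicPartition> partitions() {
        return queues.keySet();
    }
}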

streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java (12 lines changed)

@@ -66,6 +66,7 @@
private final Map<TaskId, StreamsProducer> taskProducers;
private final ProcessingMode processingMode;
private final boolean stateUpdaterEnabled;
private final boolean processingThreadsEnabled;
ActiveTaskCreator(final TopologyMetadata topologyMetadata,
final StreamsConfig applicationConfig,
@@ -78,7 +79,9 @@
final String threadId,
final UUID processId,
final Logger log,
final boolean stateUpdaterEnabled) {
final boolean stateUpdaterEnabled,
final boolean processingThreadsEnabled
) {
this.topologyMetadata = topologyMetadata;
this.applicationConfig = applicationConfig;
this.streamsMetrics = streamsMetrics;
@@ -90,6 +93,7 @@
this.threadId = threadId;
this.log = log;
this.stateUpdaterEnabled = stateUpdaterEnabled;
this.processingThreadsEnabled = processingThreadsEnabled;
createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
processingMode = processingMode(applicationConfig);
@@ -242,7 +246,8 @@
standbyTask.stateMgr,
recordCollector,
standbyTask.processorContext,
standbyTask.logContext
standbyTask.logContext,
processingThreadsEnabled
);
log.trace("Created active task {} from recycled standby task with assigned partitions {}", task.id, inputPartitions);
@@ -272,7 +277,8 @@
stateManager,
recordCollector,
context,
logContext
logContext,
processingThreadsEnabled
);
log.trace("Created active task {} with assigned partitions {}", taskId, inputPartitions);

streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java (49 lines changed)

@@ -56,7 +56,7 @@
* As a consequence of the definition, the PartitionGroup's stream-time is non-decreasing
* (i.e., it increases or stays the same over time).
*/
public class PartitionGroup {
class PartitionGroup extends AbstractPartitionGroup {
private final Logger logger;
private final Map<TopicPartition, RecordQueue> partitionQueues;
@@ -72,22 +72,6 @@
private final Map<TopicPartition, Long> idlePartitionDeadlines = new HashMap<>();
private final Map<TopicPartition, Long> fetchedLags = new HashMap<>();
static class RecordInfo {
RecordQueue queue;
ProcessorNode<?, ?, ?, ?> node() {
return queue.source();
}
TopicPartition partition() {
return queue.partition();
}
RecordQueue queue() {
return queue;
}
}
PartitionGroup(final LogContext logContext,
final Map<TopicPartition, RecordQueue> partitionQueues,
final Function<TopicPartition, OptionalLong> lagProvider,
@@ -106,7 +90,8 @@
streamTime = RecordQueue.UNKNOWN;
}
public boolean readyToProcess(final long wallClockTime) {
@Override
boolean readyToProcess(final long wallClockTime) {
if (maxTaskIdleMs == StreamsConfig.MAX_TASK_IDLE_MS_DISABLED) {
if (logger.isTraceEnabled() && !allBuffered && totalBuffered > 0) {
final Set<TopicPartition> bufferedPartitions = new HashSet<>();
@@ -209,7 +194,7 @@
}
}
// visible for testing
@Override
long partitionTimestamp(final TopicPartition partition) {
final RecordQueue queue = partitionQueues.get(partition);
if (queue == null) {
@@ -219,6 +204,7 @@
}
// creates queues for new partitions, removes old queues, saves cached records for previously assigned partitions
@Override
void updatePartitions(final Set<TopicPartition> inputPartitions, final Function<TopicPartition, RecordQueue> recordQueueCreator) {
final Set<TopicPartition> removedPartitions = new HashSet<>();
final Set<TopicPartition> newInputPartitions = new HashSet<>(inputPartitions);
@@ -241,6 +227,7 @@
allBuffered = allBuffered && newInputPartitions.isEmpty();
}
@Override
void setPartitionTime(final TopicPartition partition, final long partitionTime) {
final RecordQueue queue = partitionQueues.get(partition);
if (queue == null) {
@@ -252,11 +239,7 @@
queue.setPartitionTime(partitionTime);
}
/**
* Get the next record and queue
*
* @return StampedRecord
*/
@Override
StampedRecord nextRecord(final RecordInfo info, final long wallClockTime) {
StampedRecord record = null;
@@ -290,13 +273,7 @@
return record;
}
/**
* Adds raw records to this partition group
*
* @param partition the partition
* @param rawRecords the raw records
* @return the queue size for the partition
*/
@Override
int addRawRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
final RecordQueue recordQueue = partitionQueues.get(partition);
@@ -328,13 +305,12 @@
return Collections.unmodifiableSet(partitionQueues.keySet());
}
/**
* Return the stream-time of this partition group defined as the largest timestamp seen across all partitions
*/
@Override
long streamTime() {
return streamTime;
}
@Override
Long headRecordOffset(final TopicPartition partition) {
final RecordQueue recordQueue = partitionQueues.get(partition);
@@ -348,6 +324,7 @@
/**
* @throws IllegalStateException if the record's partition does not belong to this partition group
*/
@Override
int numBuffered(final TopicPartition partition) {
final RecordQueue recordQueue = partitionQueues.get(partition);
@@ -358,6 +335,7 @@
return recordQueue.size();
}
@Override
int numBuffered() {
return totalBuffered;
}
@@ -367,6 +345,7 @@
return allBuffered;
}
@Override
void clear() {
for (final RecordQueue queue : partitionQueues.values()) {
queue.clear();
@@ -377,12 +356,14 @@
fetchedLags.clear();
}
@Override
void close() {
for (final RecordQueue queue : partitionQueues.values()) {
queue.close();
}
}
@Override
void updateLags() {
if (maxTaskIdleMs != StreamsConfig.MAX_TASK_IDLE_MS_DISABLED) {
for (final TopicPartition tp : partitionQueues.keySet()) {

streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java (42 lines changed)

@@ -39,6 +39,7 @@
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.TimestampExtractor;
import org.apache.kafka.streams.processor.api.Record;
import org.apache.kafka.streams.processor.internals.AbstractPartitionGroup.RecordInfo;
import org.apache.kafka.streams.processor.internals.metrics.ProcessorNodeMetrics;
import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
@@ -64,7 +65,7 @@
import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
/**
* A StreamTask is associated with a {@link PartitionGroup}, and is assigned to a StreamThread for processing.
* A StreamTask is associated with a {@link AbstractPartitionGroup}, and is assigned to a StreamThread for processing.
*/
public class StreamTask extends AbstractTask implements ProcessorNodePunctuator, Task {
@@ -77,9 +78,9 @@
private final boolean eosEnabled;
private final int maxBufferedSize;
private final PartitionGroup partitionGroup;
private final AbstractPartitionGroup partitionGroup;
private final RecordCollector recordCollector;
private final PartitionGroup.RecordInfo recordInfo;
private final AbstractPartitionGroup.RecordInfo recordInfo;
private final Map<TopicPartition, Long> consumedOffsets;
private final Map<TopicPartition, Long> committedOffsets;
private final Map<TopicPartition, Long> highWatermark;
@@ -111,7 +112,7 @@
private boolean hasPendingTxCommit = false;
private Optional<Long> timeCurrentIdlingStarted;
@SuppressWarnings({"rawtypes", "this-escape"})
@SuppressWarnings({"rawtypes", "this-escape", "checkstyle:ParameterNumber"})
public StreamTask(final TaskId id,
final Set<TopicPartition> inputPartitions,
final ProcessorTopology topology,
@@ -124,7 +125,9 @@
final ProcessorStateManager stateMgr,
final RecordCollector recordCollector,
final InternalProcessorContext processorContext,
final LogContext logContext) {
final LogContext logContext,
final boolean processingThreadsEnabled
) {
super(
id,
topology,
@@ -181,19 +184,30 @@
recordQueueCreator = new RecordQueueCreator(this.logContext, config.timestampExtractor, config.deserializationExceptionHandler);
recordInfo = new PartitionGroup.RecordInfo();
recordInfo = new RecordInfo();
final Sensor enforcedProcessingSensor;
enforcedProcessingSensor = TaskMetrics.enforcedProcessingSensor(threadId, taskId, streamsMetrics);
final long maxTaskIdleMs = config.maxTaskIdleMs;
partitionGroup = new PartitionGroup(
logContext,
createPartitionQueues(),
mainConsumer::currentLag,
TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics),
enforcedProcessingSensor,
maxTaskIdleMs
);
if (processingThreadsEnabled) {
partitionGroup = new SynchronizedPartitionGroup(new PartitionGroup(
logContext,
createPartitionQueues(),
mainConsumer::currentLag,
TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics),
enforcedProcessingSensor,
maxTaskIdleMs
));
} else {
partitionGroup = new PartitionGroup(
logContext,
createPartitionQueues(),
mainConsumer::currentLag,
TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics),
enforcedProcessingSensor,
maxTaskIdleMs
);
}
stateMgr.registerGlobalStateStores(topology.globalStateStores());
committedOffsets = new HashMap<>();

streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java (131 lines changed)

@@ -49,6 +49,8 @@
import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer;
import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager.DefaultTaskExecutorCreator;
import org.apache.kafka.streams.state.internals.ThreadCache;
import java.util.Queue;
@@ -331,6 +333,7 @@
private final AtomicBoolean leaveGroupRequested = new AtomicBoolean(false);
private final boolean eosEnabled;
private final boolean stateUpdaterEnabled;
private final boolean processingThreadsEnabled;
public static StreamThread create(final TopologyMetadata topologyMetadata,
final StreamsConfig config,
@@ -375,6 +378,7 @@
final ThreadCache cache = new ThreadCache(logContext, cacheSizeBytes, streamsMetrics);
final boolean stateUpdaterEnabled = InternalConfig.getStateUpdaterEnabled(config.originals());
final boolean processingThreadsEnabled = InternalConfig.getProcessingThreadsEnabled(config.originals());
final ActiveTaskCreator activeTaskCreator = new ActiveTaskCreator(
topologyMetadata,
config,
@@ -387,7 +391,9 @@
threadId,
processId,
log,
stateUpdaterEnabled);
stateUpdaterEnabled,
processingThreadsEnabled
);
final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator(
topologyMetadata,
config,
@@ -398,6 +404,15 @@
log,
stateUpdaterEnabled);
final Tasks tasks = new Tasks(new LogContext(logPrefix));
final DefaultTaskManager schedulingTaskManager =
maybeCreateSchedulingTaskManager(processingThreadsEnabled, stateUpdaterEnabled, topologyMetadata, time, threadId, tasks);
final StateUpdater stateUpdater =
maybeCreateAndStartStateUpdater(stateUpdaterEnabled, streamsMetrics, config, changelogReader, topologyMetadata, time, clientId, threadIdx);
final TaskManager taskManager = new TaskManager(
time,
changelogReader,
@@ -405,11 +420,12 @@
logPrefix,
activeTaskCreator,
standbyTaskCreator,
new Tasks(new LogContext(logPrefix)),
tasks,
topologyMetadata,
adminClient,
stateDirectory,
maybeCreateAndStartStateUpdater(stateUpdaterEnabled, streamsMetrics, config, changelogReader, topologyMetadata, time, clientId, threadIdx)
stateUpdater,
schedulingTaskManager
);
referenceContainer.taskManager = taskManager;
@@ -452,6 +468,31 @@
return streamThread.updateThreadMetadata(getSharedAdminClientId(clientId));
}
private static DefaultTaskManager maybeCreateSchedulingTaskManager(final boolean processingThreadsEnabled,
final boolean stateUpdaterEnabled,
final TopologyMetadata topologyMetadata,
final Time time,
final String threadId,
final Tasks tasks) {
if (processingThreadsEnabled) {
if (!stateUpdaterEnabled) {
throw new IllegalStateException("Processing threads require the state updater to be enabled");
}
final DefaultTaskManager defaultTaskManager = new DefaultTaskManager(
time,
threadId,
tasks,
new DefaultTaskExecutorCreator(),
topologyMetadata.taskExecutionMetadata(),
1
);
defaultTaskManager.startTaskExecutors();
return defaultTaskManager;
}
return null;
}
private static StateUpdater maybeCreateAndStartStateUpdater(final boolean stateUpdaterEnabled,
final StreamsMetricsImpl streamsMetrics,
final StreamsConfig streamsConfig,
@@ -488,7 +529,8 @@
final Queue<StreamsException> nonFatalExceptionsToHandle,
final Runnable shutdownErrorHook,
final BiConsumer<Throwable, Boolean> streamsUncaughtExceptionHandler,
final java.util.function.Consumer<Long> cacheResizer) {
final java.util.function.Consumer<Long> cacheResizer
) {
super(threadId);
this.stateLock = new Object();
this.adminClient = adminClient;
@@ -558,6 +600,7 @@
this.numIterations = 1;
this.eosEnabled = eosEnabled(config);
this.stateUpdaterEnabled = InternalConfig.getStateUpdaterEnabled(config.originals());
this.processingThreadsEnabled = InternalConfig.getProcessingThreadsEnabled(config.originals());
}
private static final class InternalConsumerConfig extends ConsumerConfig {
@@ -620,7 +663,11 @@
if (size != -1L) {
cacheResizer.accept(size);
}
runOnce();
if (processingThreadsEnabled) {
runOnceWithProcessingThreads();
} else {
runOnceWithoutProcessingThreads();
}
// Check for a scheduled rebalance but don't trigger it until the current rebalance is done
if (!taskManager.rebalanceInProgress() && nextProbingRebalanceMs.get() < time.milliseconds()) {
@@ -765,7 +812,7 @@
* or if the task producer got fenced (EOS)
*/
// Visible for testing
void runOnce() {
void runOnceWithoutProcessingThreads() {
final long startMs = time.milliseconds();
now = startMs;
@@ -900,6 +947,78 @@
}
}
/**
* One iteration of a thread includes the following steps:
*
* 1. poll records from main consumer and add to buffer;
* 2. check the task manager for any exceptions to be handled;
* 3. commit all tasks if necessary;
*
* @throws IllegalStateException If store gets registered after initialized is already finished
* @throws StreamsException If the store's change log does not contain the partition
* @throws TaskMigratedException If another thread wrote to the changelog topic that is currently restored
* or if committing offsets failed (non-EOS)
* or if the task producer got fenced (EOS)
*/
// Visible for testing
void runOnceWithProcessingThreads() {
final long startMs = time.milliseconds();
now = startMs;
final long pollLatency;
taskManager.resumePollingForPartitionsWithAvailableSpace();
try {
pollLatency = pollPhase();
} finally {
taskManager.updateLags();
}
// Shutdown hook could potentially be triggered and transition the thread state to PENDING_SHUTDOWN during #pollRequests().
// The task manager internal states could be uninitialized if the state transition happens during #onPartitionsAssigned().
// Should only proceed when the thread is still running after #pollRequests(), because no external state mutation
// could affect the task manager state beyond this point within #runOnceWithProcessingThreads().
if (!isRunning()) {
log.debug("Thread state is already {}, skipping the run once call after poll request", state);
return;
}
long totalCommitLatency = 0L;
if (isRunning()) {
checkStateUpdater();
taskManager.maybeThrowTaskExceptionsFromProcessingThreads();
taskManager.signalTaskExecutors();
final long beforeCommitMs = now;
final int committed = maybeCommit();
final long commitLatency = Math.max(now - beforeCommitMs, 0);
totalCommitLatency += commitLatency;
if (committed > 0) {
totalCommittedSinceLastSummary += committed;
commitSensor.record(commitLatency / (double) committed, now);
if (log.isDebugEnabled()) {
log.debug("Committed all active tasks {} and standby tasks {} in {}ms",
taskManager.activeTaskIds(), taskManager.standbyTaskIds(), commitLatency);
}
}
}
now = time.milliseconds();
final long runOnceLatency = now - startMs;
pollRatioSensor.record((double) pollLatency / runOnceLatency, now);
commitRatioSensor.record((double) totalCommitLatency / runOnceLatency, now);
final boolean logProcessingSummary = now - lastLogSummaryMs > LOG_SUMMARY_INTERVAL_MS;
if (logProcessingSummary) {
log.info("Committed {} total tasks since the last update", totalCommittedSinceLastSummary);
totalCommittedSinceLastSummary = 0L;
lastLogSummaryMs = now;
}
}
private void initializeAndRestorePhase() {
final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = partitions -> resetOffsets(partitions, null);
final State stateSnapshot = state;

streams/src/main/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroup.java (101 lines changed)

@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals;
import java.util.Set;
import java.util.function.Function;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
class SynchronizedPartitionGroup extends AbstractPartitionGroup {
private final AbstractPartitionGroup wrapped;
public SynchronizedPartitionGroup(final AbstractPartitionGroup wrapped) {
this.wrapped = wrapped;
}
@Override
synchronized boolean readyToProcess(final long wallClockTime) {
return wrapped.readyToProcess(wallClockTime);
}
@Override
synchronized void updatePartitions(final Set<TopicPartition> inputPartitions, final Function<TopicPartition, RecordQueue> recordQueueCreator) {
wrapped.updatePartitions(inputPartitions, recordQueueCreator);
}
@Override
synchronized void setPartitionTime(final TopicPartition partition, final long partitionTime) {
wrapped.setPartitionTime(partition, partitionTime);
}
@Override
synchronized StampedRecord nextRecord(final RecordInfo info, final long wallClockTime) {
return wrapped.nextRecord(info, wallClockTime);
}
@Override
synchronized int addRawRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
return wrapped.addRawRecords(partition, rawRecords);
}
@Override
synchronized long partitionTimestamp(final TopicPartition partition) {
return wrapped.partitionTimestamp(partition);
}
@Override
synchronized long streamTime() {
return wrapped.streamTime();
}
@Override
synchronized Long headRecordOffset(final TopicPartition partition) {
return wrapped.headRecordOffset(partition);
}
@Override
synchronized int numBuffered() {
return wrapped.numBuffered();
}
@Override
synchronized int numBuffered(final TopicPartition tp) {
return wrapped.numBuffered(tp);
}
@Override
synchronized void clear() {
wrapped.clear();
}
@Override
synchronized void updateLags() {
wrapped.updateLags();
}
@Override
synchronized void close() {
wrapped.close();
}
@Override
synchronized Set<TopicPartition> partitions() {
return wrapped.partitions();
}
}
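A design note on the wrapper above: every method synchronizes on the SynchronizedPartitionGroup instance itself, so the scheme is only sound as long as the wrapped PartitionGroup never escapes to callers. StreamTask preserves this invariant by constructing the PartitionGroup inline inside the new SynchronizedPartitionGroup(...) call (see the StreamTask diff above), so both the polling thread and the processing threads funnel all access through the synchronized methods.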

streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java (106 lines changed)

@@ -16,6 +16,7 @@
*/
package org.apache.kafka.streams.processor.internals;
import java.util.concurrent.ExecutionException;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.DeleteRecordsResult;
import org.apache.kafka.clients.admin.RecordsToDelete;
@@ -39,6 +40,7 @@
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
import org.apache.kafka.streams.processor.internals.Task.State;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
import org.slf4j.Logger;
@@ -97,6 +99,7 @@
private final ActiveTaskCreator activeTaskCreator;
private final StandbyTaskCreator standbyTaskCreator;
private final StateUpdater stateUpdater;
private final DefaultTaskManager schedulingTaskManager;
TaskManager(final Time time,
final ChangelogReader changelogReader,
@@ -108,7 +111,9 @@
final TopologyMetadata topologyMetadata,
final Admin adminClient,
final StateDirectory stateDirectory,
final StateUpdater stateUpdater) {
final StateUpdater stateUpdater,
final DefaultTaskManager schedulingTaskManager
) {
this.time = time;
this.processId = processId;
this.logPrefix = logPrefix;
@@ -124,6 +129,7 @@
this.log = logContext.logger(getClass());
this.stateUpdater = stateUpdater;
this.schedulingTaskManager = schedulingTaskManager;
this.tasks = tasks;
this.taskExecutor = new TaskExecutor(
this.tasks,
@@ -202,6 +208,11 @@
* @throws TaskMigratedException
*/
boolean handleCorruption(final Set<TaskId> corruptedTasks) {
final Set<TaskId> activeTasks = new HashSet<>(tasks.activeTaskIds());
// We need to stop all processing, since we need to commit non-corrupted tasks as well.
maybeLockTasks(activeTasks);
final Set<Task> corruptedActiveTasks = new HashSet<>();
final Set<Task> corruptedStandbyTasks = new HashSet<>();
@@ -240,6 +251,9 @@
}
closeDirtyAndRevive(corruptedActiveTasks, true);
maybeUnlockTasks(activeTasks);
return !corruptedActiveTasks.isEmpty();
}
@@ -329,6 +343,13 @@
final Map<Task, Set<TopicPartition>> tasksToRecycle = new HashMap<>();
final Set<Task> tasksToCloseClean = new TreeSet<>(Comparator.comparing(Task::id));
final Set<TaskId> tasksToLock =
tasks.allTaskIds().stream()
.filter(x -> activeTasksToCreate.containsKey(x) || standbyTasksToCreate.containsKey(x))
.collect(Collectors.toSet());
maybeLockTasks(tasksToLock);
// first put aside those unrecognized tasks because of unknown named-topologies
tasks.clearPendingTasksToCreate();
tasks.addPendingActiveTasksToCreate(pendingTasksToCreate(activeTasksToCreate));
@@ -346,6 +367,8 @@
final Map<TaskId, RuntimeException> taskCloseExceptions = closeAndRecycleTasks(tasksToRecycle, tasksToCloseClean);
maybeUnlockTasks(tasksToLock);
maybeThrowTaskExceptions(taskCloseExceptions);
createNewTasks(activeTasksToCreate, standbyTasksToCreate);
@@ -964,6 +987,9 @@
final Map<Task, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsPerTask = new HashMap<>();
final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
final Set<TaskId> lockedTaskIds = activeRunningTaskIterable().stream().map(Task::id).collect(Collectors.toSet());
maybeLockTasks(lockedTaskIds);
for (final Task task : activeRunningTaskIterable()) {
if (remainingRevokedPartitions.containsAll(task.inputPartitions())) {
// when the task input partitions are included in the revoked list,
@@ -1057,6 +1083,8 @@
}
}
maybeUnlockTasks(lockedTaskIds);
if (firstException.get() != null) {
throw firstException.get();
}
@@ -1112,6 +1140,8 @@
private void closeRunningTasksDirty() {
final Set<Task> allTask = tasks.allTasks();
final Set<TaskId> allTaskIds = tasks.allTaskIds();
maybeLockTasks(allTaskIds);
for (final Task task : allTask) {
// Even though we've apparently dropped out of the group, we can continue safely to maintain our
// standby tasks while we rejoin.
@@ -1119,6 +1149,7 @@
closeTaskDirty(task, true);
}
}
maybeUnlockTasks(allTaskIds);
}
private void removeLostActiveTasksFromStateUpdater() {
@@ -1136,6 +1167,9 @@
if (stateUpdater != null) {
stateUpdater.signalResume();
}
if (schedulingTaskManager != null) {
schedulingTaskManager.signalTaskExecutors();
}
}
/**
@@ -1307,6 +1341,7 @@
void shutdown(final boolean clean) {
shutdownStateUpdater();
shutdownSchedulingTaskManager();
final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
@@ -1356,6 +1391,12 @@
}
}
private void shutdownSchedulingTaskManager() {
if (schedulingTaskManager != null) {
schedulingTaskManager.shutdown(Duration.ofMillis(Long.MAX_VALUE));
}
}
private void closeFailedTasksFromStateUpdater() {
final Set<Task> tasksToCloseDirty = stateUpdater.drainExceptionsAndFailedTasks().stream()
.flatMap(exAndTasks -> exAndTasks.getTasks().stream()).collect(Collectors.toSet());
@@ -1415,6 +1456,12 @@
void closeAndCleanUpTasks(final Collection<Task> activeTasks, final Collection<Task> standbyTasks, final boolean clean) {
final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
final Set<TaskId> ids =
activeTasks.stream()
.map(Task::id)
.collect(Collectors.toSet());
maybeLockTasks(ids);
final Set<Task> tasksToCloseDirty = new HashSet<>();
tasksToCloseDirty.addAll(tryCloseCleanActiveTasks(activeTasks, clean, firstException));
tasksToCloseDirty.addAll(tryCloseCleanStandbyTasks(standbyTasks, clean, firstException));
@@ -1423,6 +1470,8 @@
closeTaskDirty(task, true);
}
maybeUnlockTasks(ids);
final RuntimeException exception = firstException.get();
if (exception != null) {
throw exception;
@@ -1681,6 +1730,16 @@
}
}
/**
* Wake-up any sleeping processing threads.
*/
public void signalTaskExecutors() {
if (schedulingTaskManager != null) {
// Wake up sleeping task executors after every poll, in case there is processing or punctuation to-do.
schedulingTaskManager.signalTaskExecutors();
}
}
/**
* Take records and add them to each respective task
*
@@ -1700,6 +1759,42 @@
}
}
private void maybeLockTasks(final Set<TaskId> ids) {
if (schedulingTaskManager != null && !ids.isEmpty()) {
if (log.isDebugEnabled()) {
log.debug("Locking tasks {}", ids.stream().map(TaskId::toString).collect(Collectors.joining(", ")));
}
boolean locked = false;
while (!locked) {
try {
schedulingTaskManager.lockTasks(ids).get();
locked = true;
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for tasks {} to be locked",
ids.stream().map(TaskId::toString).collect(Collectors.joining(",")));
} catch (final ExecutionException e) {
log.info("Failed to lock tasks");
throw new RuntimeException(e);
}
}
}
}
private void maybeUnlockTasks(final Set<TaskId> ids) {
if (schedulingTaskManager != null && !ids.isEmpty()) {
if (log.isDebugEnabled()) {
log.debug("Unlocking tasks {}", ids.stream().map(TaskId::toString).collect(Collectors.joining(", ")));
}
schedulingTaskManager.unlockTasks(ids);
}
}
public void maybeThrowTaskExceptionsFromProcessingThreads() {
if (schedulingTaskManager != null) {
maybeThrowTaskExceptions(schedulingTaskManager.drainUncaughtExceptions());
}
}
/**
* @throws TaskMigratedException if committing offsets failed (non-EOS)
* or if the task producer got fenced (EOS)
@@ -1709,6 +1804,14 @@
*/
int commit(final Collection<Task> tasksToCommit) {
int committed = 0;
final Set<TaskId> ids =
tasksToCommit.stream()
.map(Task::id)
.collect(Collectors.toSet());
maybeLockTasks(ids);
// We have to throw the first uncaught exception after locking the tasks, to not attempt to commit failure records.
maybeThrowTaskExceptionsFromProcessingThreads();
final Map<Task, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
try {
@@ -1719,6 +1822,7 @@
.forEach(t -> t.maybeInitTaskTimeoutOrThrow(time.milliseconds(), timeoutException));
}
maybeUnlockTasks(ids);
return committed;
}

streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java (5 lines changed)

@@ -349,6 +349,11 @@
return tasks;
}
@Override
public synchronized Collection<TaskId> activeTaskIds() {
return Collections.unmodifiableCollection(activeTasksPerId.keySet());
}
@Override
public synchronized Collection<Task> activeTasks() {
return Collections.unmodifiableCollection(activeTasksPerId.values());

streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java (2 lines changed)

@@ -85,6 +85,8 @@
Collection<Task> tasks(final Collection<TaskId> taskIds);
Collection<TaskId> activeTaskIds();
Collection<Task> activeTasks();
Set<Task> allTasks();

streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskExecutor.java (11 lines changed)

@@ -194,7 +194,16 @@
// flush the task before giving it back to task manager, if we are not handing it back because of an error.
if (!taskManager.hasUncaughtException(currentTask.id())) {
currentTask.flush();
try {
currentTask.flush();
} catch (final StreamsException e) {
log.error(String.format("Failed to flush stream task %s due to the following error:", currentTask.id()), e);
e.setTaskId(currentTask.id());
taskManager.setUncaughtException(e, currentTask.id());
} catch (final RuntimeException e) {
log.error(String.format("Failed to flush stream task %s due to the following error:", currentTask.id()), e);
taskManager.setUncaughtException(new StreamsException(e, currentTask.id()), currentTask.id());
}
}
taskManager.unassignTask(currentTask, DefaultTaskExecutor.this);

streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java (6 lines changed)

@@ -319,9 +319,9 @@
exception.getMessage());
}
public Map<TaskId, StreamsException> drainUncaughtExceptions() {
final Map<TaskId, StreamsException> returnValue = returnWithTasksLocked(() -> {
final Map<TaskId, StreamsException> result = new HashMap<>(uncaughtExceptions);
public Map<TaskId, RuntimeException> drainUncaughtExceptions() {
final Map<TaskId, RuntimeException> returnValue = returnWithTasksLocked(() -> {
final Map<TaskId, RuntimeException> result = new HashMap<>(uncaughtExceptions);
uncaughtExceptions.clear();
return result;
});

streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/TaskManager.java (2 lines changed)

@@ -118,7 +118,7 @@
*
* @return A map from task ID to the exception that occurred.
*/
Map<TaskId, StreamsException> drainUncaughtExceptions();
Map<TaskId, RuntimeException> drainUncaughtExceptions();
/**
* Can be used to check if a specific task has an uncaught exception.

streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java (2 lines changed)

@@ -442,7 +442,7 @@
@Override
public void init(final ProcessorContext context) {
context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, timestamp -> {
if (Thread.currentThread().getName().endsWith("StreamThread-1") && injectError.get()) {
if (Thread.currentThread().getName().contains("StreamThread-1") && injectError.get()) {
injectError.set(false);
throw new RuntimeException("BOOM");
}

streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java (31 lines changed)

@@ -34,6 +34,7 @@
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.StreamsConfig.InternalConfig;
import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
import org.apache.kafka.streams.kstream.KStream;
@@ -149,18 +150,24 @@
private String stateTmpDir;
@SuppressWarnings("deprecation")
@Parameters(name = "{0}")
public static Collection<String[]> data() {
return Arrays.asList(new String[][]{
{StreamsConfig.AT_LEAST_ONCE},
{StreamsConfig.EXACTLY_ONCE},
{StreamsConfig.EXACTLY_ONCE_V2}
@Parameters(name = "{0}, processing threads = {1}")
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][]{
{StreamsConfig.AT_LEAST_ONCE, false},
{StreamsConfig.EXACTLY_ONCE, false},
{StreamsConfig.EXACTLY_ONCE_V2, false},
{StreamsConfig.AT_LEAST_ONCE, true},
{StreamsConfig.EXACTLY_ONCE, true},
{StreamsConfig.EXACTLY_ONCE_V2, true}
});
}
@Parameter
@Parameter(0)
public String eosConfig;
@Parameter(1)
public boolean processingThreadsEnabled;
@Before
public void createTopics() throws Exception {
applicationId = "appId-" + TEST_NUMBER.getAndIncrement();
@@ -876,10 +883,15 @@
LOG.info(dummyHostName + " is executing the injected stall");
stallingHost.set(dummyHostName);
while (doStall) {
final StreamThread thread = (StreamThread) Thread.currentThread();
if (thread.isInterrupted() || !thread.isRunning()) {
final Thread thread = Thread.currentThread();
if (thread.isInterrupted()) {
throw new RuntimeException("Detected we've been interrupted.");
}
if (!processingThreadsEnabled) {
if (!((StreamThread) thread).isRunning()) {
throw new RuntimeException("Detected we've been interrupted.");
}
}
try {
Thread.sleep(100);
} catch (final InterruptedException e) {
@@ -943,6 +955,7 @@
properties.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
properties.put(StreamsConfig.STATE_DIR_CONFIG, stateTmpDir + appDir);
properties.put(StreamsConfig.APPLICATION_SERVER_CONFIG, dummyHostName + ":2142");
properties.put(InternalConfig.PROCESSING_THREADS_ENABLED, processingThreadsEnabled);
final Properties config = StreamsTestUtils.getStreamsConfig(
applicationId,

streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java (13 lines changed)

@@ -30,6 +30,7 @@
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.io.IOException;
@@ -94,11 +95,12 @@
}
private static Stream<Boolean> parameters() {
private static Stream<Arguments> parameters() {
return Stream.of(
Boolean.TRUE,
Boolean.FALSE
);
Arguments.of(false, false),
Arguments.of(true, false),
Arguments.of(true, true)
);
}
// In this test, we try to keep creating new stream, and closing the old one, to maintain only 3 streams alive.
@@ -107,7 +109,7 @@
// (1) 10 min timeout, (2) 30 tries of polling without getting any data
@ParameterizedTest
@MethodSource("parameters")
public void shouldWorkWithRebalance(final boolean stateUpdaterEnabled) throws InterruptedException {
public void shouldWorkWithRebalance(final boolean stateUpdaterEnabled, final boolean processingThreadsEnabled) throws InterruptedException {
Exit.setExitProcedure((statusCode, message) -> {
throw new AssertionError("Test called exit(). code:" + statusCode + " message:" + message);
});
@@ -128,6 +130,7 @@
final Properties props = new Properties();
props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
props.put(InternalConfig.STATE_UPDATER_ENABLED, stateUpdaterEnabled);
props.put(InternalConfig.PROCESSING_THREADS_ENABLED, processingThreadsEnabled);
// decrease the session timeout so that we can trigger the rebalance soon after old client left closed
props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000);

streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java (1 line changed)

@@ -497,6 +497,7 @@
"clientId-StreamThread-0",
uuid,
new LogContext().logger(ActiveTaskCreator.class),
false,
false);
assertThat(

streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java (26 lines changed)

@@ -33,6 +33,7 @@
import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
import org.apache.kafka.streams.processor.TimestampExtractor;
import org.apache.kafka.common.utils.LogCaptureAppender;
import org.apache.kafka.streams.processor.internals.AbstractPartitionGroup.RecordInfo;
import org.apache.kafka.test.InternalMockProcessorContext;
import org.apache.kafka.test.MockSourceNode;
import org.apache.kafka.test.MockTimestampExtractor;
@@ -133,7 +134,7 @@
private void testFirstBatch(final PartitionGroup group) {
StampedRecord record;
final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
final PartitionGroup.RecordInfo info = new RecordInfo();
assertThat(group.numBuffered(), is(0));
// add three records with timestamps 1, 3, 5 to partition-1
@@ -193,7 +194,7 @@
private void testSecondBatch(final PartitionGroup group) {
StampedRecord record;
final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
final PartitionGroup.RecordInfo info = new RecordInfo();
// add 2 more records with timestamp 2, 4 to partition-1
final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList(
@@ -316,7 +317,7 @@
assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
StampedRecord record;
final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
final PartitionGroup.RecordInfo info = new RecordInfo();
// get first two records from partition 1
record = group.nextRecord(info, time.milliseconds());
@@ -445,15 +446,16 @@
new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue),
new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue));
group.addRawRecords(partition1, list);
group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds());
group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds());
group.nextRecord(new RecordInfo(), time.milliseconds());
group.nextRecord(new RecordInfo(), time.milliseconds());
group.updateLags();
group.clear();
assertThat(group.numBuffered(), equalTo(0));
assertThat(group.streamTime(), equalTo(RecordQueue.UNKNOWN));
assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), equalTo(null));
assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), equalTo(null));
assertThat(group.partitionTimestamp(partition1), equalTo(RecordQueue.UNKNOWN));
hasNoFetchedLag(group, partition1);
@@ -475,7 +477,7 @@ public class PartitionGroupTest {
group.addRawRecords(partition2, list2);
assertEquals(list1.size() + list2.size(), group.numBuffered());
assertTrue(group.allPartitionsBufferedLocally());
- group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds());
+ group.nextRecord(new RecordInfo(), time.milliseconds());
// shrink list of queues
group.updatePartitions(mkSet(createPartition2()), p -> {
@@ -487,7 +489,7 @@ public class PartitionGroupTest {
assertEquals(list2.size(), group.numBuffered());
assertEquals(1, group.streamTime());
assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1));
- assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records
+ assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records
assertThat(group.partitionTimestamp(partition2), equalTo(2L));
}
@@ -508,7 +510,7 @@ public class PartitionGroupTest {
assertEquals(list1.size(), group.numBuffered());
assertTrue(group.allPartitionsBufferedLocally());
- group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds());
+ group.nextRecord(new RecordInfo(), time.milliseconds());
// expand list of queues
group.updatePartitions(mkSet(createPartition1(), createPartition2()), p -> {
@@ -521,7 +523,7 @@ public class PartitionGroupTest {
assertEquals(1, group.streamTime());
assertThat(group.partitionTimestamp(partition1), equalTo(1L));
assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN));
- assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records
+ assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records
}
@Test
@@ -540,7 +542,7 @@ public class PartitionGroupTest {
group.addRawRecords(partition1, list1);
assertEquals(list1.size(), group.numBuffered());
assertTrue(group.allPartitionsBufferedLocally());
- group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds());
+ group.nextRecord(new RecordInfo(), time.milliseconds());
// expand and shrink list of queues
group.updatePartitions(mkSet(createPartition2()), p -> {
@@ -553,7 +555,7 @@ public class PartitionGroupTest {
assertEquals(1, group.streamTime());
assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1));
assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN));
- assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), nullValue()); // all available records removed
+ assertThat(group.nextRecord(new RecordInfo(), time.milliseconds()), nullValue()); // all available records removed
}
@Test

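The repetitive change above is purely mechanical: RecordInfo now lives on the common superclass, so the tests import AbstractPartitionGroup.RecordInfo and drop the PartitionGroup qualifier. An abridged sketch of the superclass shape these tests rely on (the method list is trimmed to what is exercised here, not the complete class):

    // Abridged sketch; the real class declares more operations (addRawRecords,
    // partitionTimestamp, updateLags, close, ...) as the tests above show.
    abstract class AbstractPartitionGroup {

        // Mutable holder describing which queue nextRecord() drained from.
        static class RecordInfo {
            RecordQueue queue;

            RecordQueue queue() {
                return queue;
            }
        }

        abstract boolean readyToProcess(long wallClockTime);

        abstract StampedRecord nextRecord(RecordInfo info, long wallClockTime);

        abstract int numBuffered();
    }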
32
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java

@@ -1841,7 +1841,9 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext);
+ logContext,
+ false
+ );
task.initializeIfNeeded();
task.completeRestoration(noOpResetter -> { });
@@ -2504,7 +2506,9 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext)
+ logContext,
+ false
+ )
);
assertThat(exception.getMessage(), equalTo("Invalid topology: " +
@@ -2700,7 +2704,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}
@@ -2741,7 +2746,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}
@@ -2774,7 +2780,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}
@@ -2812,7 +2819,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}
@@ -2852,7 +2860,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}
@@ -2893,7 +2902,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}
@@ -2932,7 +2942,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}
@@ -2966,7 +2977,8 @@ public class StreamTaskTest {
stateManager,
recordCollector,
context,
- logContext
+ logContext,
+ false
);
}

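Every StreamTask constructed in these tests now receives a trailing boolean after logContext. The hunks above do not spell out its meaning, but a plausible reading, consistent with the SynchronizedPartitionGroup test further down, is that the flag selects the synchronized queue variant when processing threads are enabled. A hypothetical sketch (the factory name is assumed, not from the patch):

    // Hypothetical sketch of what the trailing flag could control inside
    // StreamTask; the real wiring may differ.
    final class PartitionGroupChoice {
        static AbstractPartitionGroup create(final PartitionGroup unsynchronized,
                                             final boolean processingThreadsEnabled) {
            // Processing threads drain the queue concurrently with the polling
            // thread, so wrap it in the monitor-based variant in that mode.
            return processingThreadsEnabled
                ? new SynchronizedPartitionGroup(unsynchronized)
                : unsynchronized;
        }
    }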
838
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java

File diff suppressed because it is too large.

201
streams/src/test/java/org/apache/kafka/streams/processor/internals/SynchronizedPartitionGroupTest.java

@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.internals.AbstractPartitionGroup.RecordInfo;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.util.Collections;
import java.util.Set;
import java.util.function.Function;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
public class SynchronizedPartitionGroupTest {
@Mock
private AbstractPartitionGroup wrapped;
private SynchronizedPartitionGroup synchronizedPartitionGroup;
private AutoCloseable closeable;
@BeforeEach
public void setUp() {
closeable = MockitoAnnotations.openMocks(this);
synchronizedPartitionGroup = new SynchronizedPartitionGroup(wrapped);
}
@AfterEach
public void tearDown() throws Exception {
closeable.close();
}
@Test
public void testReadyToProcess() {
final long wallClockTime = 0L;
when(wrapped.readyToProcess(wallClockTime)).thenReturn(true);
synchronizedPartitionGroup.readyToProcess(wallClockTime);
verify(wrapped, times(1)).readyToProcess(wallClockTime);
}
@Test
public void testUpdatePartitions() {
final Set<TopicPartition> inputPartitions = Collections.singleton(new TopicPartition("topic", 0));
@SuppressWarnings("unchecked") final Function<TopicPartition, RecordQueue> recordQueueCreator = (Function<TopicPartition, RecordQueue>) mock(Function.class);
synchronizedPartitionGroup.updatePartitions(inputPartitions, recordQueueCreator);
verify(wrapped, times(1)).updatePartitions(inputPartitions, recordQueueCreator);
}
@Test
public void testSetPartitionTime() {
final TopicPartition partition = new TopicPartition("topic", 0);
final long partitionTime = 12345678L;
synchronizedPartitionGroup.setPartitionTime(partition, partitionTime);
verify(wrapped, times(1)).setPartitionTime(partition, partitionTime);
}
@Test
public void testNextRecord() {
final RecordInfo info = mock(RecordInfo.class);
final long wallClockTime = 12345678L;
final StampedRecord stampedRecord = mock(StampedRecord.class);
when(wrapped.nextRecord(info, wallClockTime)).thenReturn(stampedRecord);
final StampedRecord result = synchronizedPartitionGroup.nextRecord(info, wallClockTime);
assertEquals(stampedRecord, result);
verify(wrapped, times(1)).nextRecord(info, wallClockTime);
}
@Test
public void testAddRawRecords() {
final TopicPartition partition = new TopicPartition("topic", 0);
@SuppressWarnings("unchecked") final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords = (Iterable<ConsumerRecord<byte[], byte[]>>) mock(Iterable.class);
when(wrapped.addRawRecords(partition, rawRecords)).thenReturn(1);
final int result = synchronizedPartitionGroup.addRawRecords(partition, rawRecords);
assertEquals(1, result);
verify(wrapped, times(1)).addRawRecords(partition, rawRecords);
}
@Test
public void testPartitionTimestamp() {
final TopicPartition partition = new TopicPartition("topic", 0);
final long timestamp = 12345678L;
when(wrapped.partitionTimestamp(partition)).thenReturn(timestamp);
final long result = synchronizedPartitionGroup.partitionTimestamp(partition);
assertEquals(timestamp, result);
verify(wrapped, times(1)).partitionTimestamp(partition);
}
@Test
public void testStreamTime() {
final long streamTime = 12345678L;
when(wrapped.streamTime()).thenReturn(streamTime);
final long result = synchronizedPartitionGroup.streamTime();
assertEquals(streamTime, result);
verify(wrapped, times(1)).streamTime();
}
@Test
public void testHeadRecordOffset() {
final TopicPartition partition = new TopicPartition("topic", 0);
final Long recordOffset = 0L;
when(wrapped.headRecordOffset(partition)).thenReturn(recordOffset);
final Long result = synchronizedPartitionGroup.headRecordOffset(partition);
assertEquals(recordOffset, result);
verify(wrapped, times(1)).headRecordOffset(partition);
}
@Test
public void testNumBuffered() {
final int numBuffered = 1;
when(wrapped.numBuffered()).thenReturn(numBuffered);
final int result = synchronizedPartitionGroup.numBuffered();
assertEquals(numBuffered, result);
verify(wrapped, times(1)).numBuffered();
}
@Test
public void testNumBufferedWithTopicPartition() {
final TopicPartition partition = new TopicPartition("topic", 0);
final int numBuffered = 1;
when(wrapped.numBuffered(partition)).thenReturn(numBuffered);
final int result = synchronizedPartitionGroup.numBuffered(partition);
assertEquals(numBuffered, result);
verify(wrapped, times(1)).numBuffered(partition);
}
@Test
public void testClear() {
synchronizedPartitionGroup.clear();
verify(wrapped, times(1)).clear();
}
@Test
public void testUpdateLags() {
synchronizedPartitionGroup.updateLags();
verify(wrapped, times(1)).updateLags();
}
@Test
public void testClose() {
synchronizedPartitionGroup.close();
verify(wrapped, times(1)).close();
}
@Test
public void testPartitions() {
final Set<TopicPartition> partitions = Collections.singleton(new TopicPartition("topic", 0));
when(wrapped.partitions()).thenReturn(partitions);
final Set<TopicPartition> result = synchronizedPartitionGroup.partitions();
assertEquals(partitions, result);
verify(wrapped, times(1)).partitions();
}
}

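The new test pins down a pure delegation contract: each call takes the lock, forwards to the wrapped group, and passes the return value through unchanged. A minimal sketch of that wrapper pattern, abridged to the three methods declared in the superclass sketch earlier:

    // Minimal sketch of the monitor-based wrapper the test exercises; the
    // real class delegates every AbstractPartitionGroup method this way.
    class SynchronizedPartitionGroupSketch extends AbstractPartitionGroup {

        private final AbstractPartitionGroup wrapped;

        SynchronizedPartitionGroupSketch(final AbstractPartitionGroup wrapped) {
            this.wrapped = wrapped;
        }

        @Override
        synchronized boolean readyToProcess(final long wallClockTime) {
            return wrapped.readyToProcess(wallClockTime);
        }

        @Override
        synchronized StampedRecord nextRecord(final RecordInfo info, final long wallClockTime) {
            return wrapped.nextRecord(info, wallClockTime);
        }

        @Override
        synchronized int numBuffered() {
            return wrapped.numBuffered();
        }
    }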
124
streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java

@@ -26,6 +26,7 @@ import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.KafkaFuture;
import org.apache.kafka.common.Metric;
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.TopicPartition;
@@ -48,6 +49,7 @@ import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks;
import org.apache.kafka.streams.processor.internals.Task.State;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig;
import org.apache.kafka.common.utils.LogCaptureAppender;
import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
@@ -199,6 +201,7 @@ public class TaskManagerTest {
@org.mockito.Mock
private Admin adminClient;
final StateUpdater stateUpdater = Mockito.mock(StateUpdater.class);
final DefaultTaskManager schedulingTaskManager = Mockito.mock(DefaultTaskManager.class);
private TaskManager taskManager;
private TopologyMetadata topologyMetadata;
@@ -212,16 +215,21 @@ public class TaskManagerTest {
@Before
public void setUp() {
- taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, false);
+ taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, null, false);
}
private TaskManager setUpTaskManager(final ProcessingMode processingMode, final boolean stateUpdaterEnabled) {
- return setUpTaskManager(processingMode, null, stateUpdaterEnabled);
+ return setUpTaskManager(processingMode, null, stateUpdaterEnabled, false);
}
private TaskManager setUpTaskManager(final ProcessingMode processingMode, final TasksRegistry tasks, final boolean stateUpdaterEnabled) {
return setUpTaskManager(processingMode, tasks, stateUpdaterEnabled, false);
}
private TaskManager setUpTaskManager(final ProcessingMode processingMode,
final TasksRegistry tasks,
- final boolean stateUpdaterEnabled) {
+ final boolean stateUpdaterEnabled,
+ final boolean processingThreadsEnabled) {
topologyMetadata = new TopologyMetadata(topologyBuilder, new DummyStreamsConfig(processingMode));
final TaskManager taskManager = new TaskManager(
time,
@@ -234,7 +242,8 @@ public class TaskManagerTest {
topologyMetadata,
adminClient,
stateDirectory,
- stateUpdaterEnabled ? stateUpdater : null
+ stateUpdaterEnabled ? stateUpdater : null,
+ processingThreadsEnabled ? schedulingTaskManager : null
);
taskManager.setMainConsumer(consumer);
return taskManager;
@@ -287,6 +296,103 @@ public class TaskManagerTest {
Mockito.verify(standbyTask).resume();
}
@Test
public void shouldLockAllTasksOnCorruptionWithProcessingThreads() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.activeTaskIds()).thenReturn(mkSet(taskId00, taskId01));
when(tasks.task(taskId00)).thenReturn(activeTask1);
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
replay(consumer);
taskManager.handleCorruption(mkSet(taskId00));
verify(consumer);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockCommitableTasksOnCorruptionWithProcessingThreads() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final StreamTask activeTask2 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId01Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.commit(mkSet(activeTask1, activeTask2));
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockActiveOnHandleAssignmentWithProcessingThreads() {
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.allTaskIds()).thenReturn(mkSet(taskId00, taskId01));
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.handleAssignment(
mkMap(mkEntry(taskId00, taskId00Partitions)),
mkMap(mkEntry(taskId01, taskId01Partitions))
);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockAffectedTasksOnHandleRevocation() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final StreamTask activeTask2 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId01Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.allTasks()).thenReturn(mkSet(activeTask1, activeTask2));
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.handleRevocation(taskId01Partitions);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00, taskId01));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00, taskId01));
}
@Test
public void shouldLockTasksOnClose() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId00Partitions).build();
final StreamTask activeTask2 = statefulTask(taskId01, taskId01ChangelogPartitions)
.inState(State.RUNNING)
.withInputPartitions(taskId01Partitions).build();
final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
when(tasks.allTasks()).thenReturn(mkSet(activeTask1, activeTask2));
final KafkaFuture<Void> mockFuture = KafkaFuture.completedFuture(null);
when(schedulingTaskManager.lockTasks(any())).thenReturn(mockFuture);
taskManager.closeAndCleanUpTasks(mkSet(activeTask1), mkSet(), false);
Mockito.verify(schedulingTaskManager).lockTasks(mkSet(taskId00));
Mockito.verify(schedulingTaskManager).unlockTasks(mkSet(taskId00));
}
@Test
public void shouldResumePollingForPartitionsWithAvailableSpaceForAllActiveTasks() {
final StreamTask activeTask1 = statefulTask(taskId00, taskId00ChangelogPartitions)
@@ -3553,6 +3659,16 @@ public class TaskManagerTest {
Mockito.verify(failedStatefulTask).closeDirty();
}
@Test
public void shouldShutdownSchedulingTaskManager() {
final TasksRegistry tasks = mock(TasksRegistry.class);
final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true, true);
taskManager.shutdown(true);
Mockito.verify(schedulingTaskManager).shutdown(Duration.ofMillis(Long.MAX_VALUE));
}
@Test
public void shouldShutDownStateUpdaterAndAddRestoredTasksToTaskRegistry() {
final TasksRegistry tasks = mock(TasksRegistry.class);

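The new TaskManagerTest cases all assert the same bracket: a state-changing operation first locks the affected tasks on the scheduling task manager, waits for the returned future, and unlocks in all cases afterwards. A hedged sketch of that pattern, under the assumption that lockTasks returns a KafkaFuture as the mocks above suggest (the helper name and error handling are assumptions, not the actual TaskManager code):

    import java.util.Set;
    import java.util.concurrent.ExecutionException;

    import org.apache.kafka.common.KafkaFuture;
    import org.apache.kafka.streams.errors.StreamsException;
    import org.apache.kafka.streams.processor.TaskId;
    import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;

    final class LockingSketch {
        // Hypothetical helper illustrating the lock/unlock bracket the tests
        // above verify; the real TaskManager may structure this differently.
        static void runWithLockedTasks(final DefaultTaskManager schedulingTaskManager,
                                       final Set<TaskId> taskIds,
                                       final Runnable action) {
            if (schedulingTaskManager == null) { // processing threads disabled
                action.run();
                return;
            }
            final KafkaFuture<Void> lockFuture = schedulingTaskManager.lockTasks(taskIds);
            try {
                lockFuture.get(); // wait until no processing thread still owns a task
                action.run();
            } catch (final InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new StreamsException("Interrupted while locking tasks " + taskIds, e);
            } catch (final ExecutionException e) {
                throw new StreamsException("Failed to lock tasks " + taskIds, e);
            } finally {
                schedulingTaskManager.unlockTasks(taskIds);
            }
        }
    }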
5
streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java

@@ -460,7 +460,10 @@ public class StreamThreadStateStoreProviderTest {
new MockTime(),
stateManager,
recordCollector,
- context, logContext);
+ context,
+ logContext,
+ false
+ );
}
private void mockThread(final boolean initialized) {

4
streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java

@@ -522,7 +522,9 @@ public class TopologyTestDriver implements Closeable {
stateManager,
recordCollector,
context,
- logContext);
+ logContext,
+ false
+ );
task.initializeIfNeeded();
task.completeRestoration(noOpResetter -> { });
task.processorContext().setRecordContext(null);
