diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index 42da481a8c4..de91e801de7 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -491,7 +491,7 @@ public class TaskManager { } final TaskId taskId = task.id(); if (activeTasksToCreate.containsKey(taskId)) { - handleReAssignedActiveTask(task, activeTasksToCreate.get(taskId)); + handleReassignedActiveTask(task, activeTasksToCreate.get(taskId)); activeTasksToCreate.remove(taskId); } else if (standbyTasksToCreate.containsKey(taskId)) { tasksToRecycle.put(task, standbyTasksToCreate.get(taskId)); @@ -502,22 +502,18 @@ public class TaskManager { } } - private void handleReAssignedActiveTask(final Task task, + private void handleReassignedActiveTask(final Task task, final Set inputPartitions) { if (tasks.updateActiveTaskInputPartitions(task, inputPartitions)) { task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id())); } if (task.state() == State.SUSPENDED) { + tasks.removeTask(task); task.resume(); - moveTaskFromTasksRegistryToStateUpdater(task); + stateUpdater.add(task); } } - private void moveTaskFromTasksRegistryToStateUpdater(final Task task) { - tasks.removeTask(task); - stateUpdater.add(task); - } - private void handleTasksInStateUpdater(final Map> activeTasksToCreate, final Map> standbyTasksToCreate) { for (final Task task : stateUpdater.getTasks()) { diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java index 894c2587578..a28fb22766d 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java @@ -20,6 +20,7 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.internals.PendingUpdateAction.Action; +import org.apache.kafka.streams.processor.internals.Task.State; import org.slf4j.Logger; import java.util.Collection; @@ -240,8 +241,8 @@ class Tasks implements TasksRegistry { public synchronized void removeTask(final Task taskToRemove) { final TaskId taskId = taskToRemove.id(); - if (taskToRemove.state() != Task.State.CLOSED) { - throw new IllegalStateException("Attempted to remove a task that is not closed: " + taskId); + if (taskToRemove.state() != Task.State.CLOSED && taskToRemove.state() != State.SUSPENDED) { + throw new IllegalStateException("Attempted to remove a task that is not closed or suspended: " + taskId); } if (taskToRemove.isActive()) { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java index 85ab9c68688..add26097a65 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -418,7 +418,7 @@ public class TaskManagerTest { } @Test - public void shouldKeepReAssignedActiveTaskInStateUpdater() { + public void shouldKeepReassignedActiveTaskInStateUpdater() { final StreamTask reassignedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) .inState(State.RESTORING) .withInputPartitions(taskId03Partitions).build(); @@ -436,21 +436,41 @@ public class TaskManagerTest { } @Test - public void shouldRemoveReAssignedRevokedActiveTaskInStateUpdaterFromPendingTaskToSuspend() { - final StreamTask reAssignedRevokedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) + public void shouldMoveReassignedSuspendedActiveTaskToStateUpdater() { + final StreamTask reassignedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) + .inState(State.SUSPENDED) + .withInputPartitions(taskId03Partitions).build(); + final TasksRegistry tasks = Mockito.mock(TasksRegistry.class); + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); + when(tasks.allTasks()).thenReturn(mkSet(reassignedActiveTask)); + + taskManager.handleAssignment( + mkMap(mkEntry(reassignedActiveTask.id(), reassignedActiveTask.inputPartitions())), + Collections.emptyMap() + ); + + Mockito.verify(tasks).removeTask(reassignedActiveTask); + Mockito.verify(stateUpdater).add(reassignedActiveTask); + Mockito.verify(activeTaskCreator).createTasks(consumer, Collections.emptyMap()); + Mockito.verify(standbyTaskCreator).createTasks(Collections.emptyMap()); + } + + @Test + public void shouldRemoveReassignedRevokedActiveTaskInStateUpdaterFromPendingTaskToSuspend() { + final StreamTask reassignedRevokedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) .inState(State.RESTORING) .withInputPartitions(taskId03Partitions).build(); final TasksRegistry tasks = Mockito.mock(TasksRegistry.class); final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); - when(stateUpdater.getTasks()).thenReturn(mkSet(reAssignedRevokedActiveTask)); + when(stateUpdater.getTasks()).thenReturn(mkSet(reassignedRevokedActiveTask)); taskManager.handleAssignment( - mkMap(mkEntry(reAssignedRevokedActiveTask.id(), reAssignedRevokedActiveTask.inputPartitions())), + mkMap(mkEntry(reassignedRevokedActiveTask.id(), reassignedRevokedActiveTask.inputPartitions())), Collections.emptyMap() ); Mockito.verify(activeTaskCreator).createTasks(consumer, Collections.emptyMap()); - Mockito.verify(tasks).removePendingActiveTaskToSuspend(reAssignedRevokedActiveTask.id()); + Mockito.verify(tasks).removePendingActiveTaskToSuspend(reassignedRevokedActiveTask.id()); Mockito.verify(standbyTaskCreator).createTasks(Collections.emptyMap()); } @@ -477,17 +497,17 @@ public class TaskManagerTest { } @Test - public void shouldKeepReAssignedStandbyTaskInStateUpdater() { - final StandbyTask reAssignedStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions) + public void shouldKeepReassignedStandbyTaskInStateUpdater() { + final StandbyTask reassignedStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions) .inState(State.RUNNING) .withInputPartitions(taskId02Partitions).build(); final TasksRegistry tasks = Mockito.mock(TasksRegistry.class); final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); - when(stateUpdater.getTasks()).thenReturn(mkSet(reAssignedStandbyTask)); + when(stateUpdater.getTasks()).thenReturn(mkSet(reassignedStandbyTask)); taskManager.handleAssignment( Collections.emptyMap(), - mkMap(mkEntry(reAssignedStandbyTask.id(), reAssignedStandbyTask.inputPartitions())) + mkMap(mkEntry(reassignedStandbyTask.id(), reassignedStandbyTask.inputPartitions())) ); Mockito.verify(activeTaskCreator).createTasks(consumer, Collections.emptyMap()); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java index 23aee1937c0..c65756d41f4 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java @@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.State; import org.junit.jupiter.api.Test; import java.util.Collections; @@ -31,10 +32,12 @@ import static org.apache.kafka.common.utils.Utils.mkSet; import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.standbyTask; import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask; import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statelessTask; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class TasksTest { @@ -52,6 +55,20 @@ public class TasksTest { private final Tasks tasks = new Tasks(new LogContext()); + @Test + public void shouldCheckStateWhenRemoveTask() { + final StreamTask closedTask = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).inState(State.CLOSED).build(); + final StandbyTask suspendedTask = standbyTask(TASK_0_1, mkSet(TOPIC_PARTITION_A_1)).inState(State.SUSPENDED).build(); + final StreamTask runningTask = statelessTask(TASK_1_0).inState(State.RUNNING).build(); + + tasks.addActiveTasks(mkSet(closedTask, runningTask)); + tasks.addStandbyTasks(Collections.singletonList(suspendedTask)); + + assertDoesNotThrow(() -> tasks.removeTask(closedTask)); + assertDoesNotThrow(() -> tasks.removeTask(suspendedTask)); + assertThrows(IllegalStateException.class, () -> tasks.removeTask(runningTask)); + } + @Test public void shouldKeepAddedTasks() { final StreamTask statefulTask = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).build();