From e7e399b9409b42f82d7ce57b99a461c465e5849d Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Tue, 17 Oct 2023 14:32:41 +0200 Subject: [PATCH] MINOR: allow removing a suspended task from task registry. (#14555) When we get a suspended task re-assigned in the eager rebalance protocol, we have to add the task back to the state updater so that it has a chance to catch up with its changelog. This was prevented by a check in Tasks, which disallows removing SUSPENDED tasks from the task registry. I couldn't find a reason why this must be an invariant of the task registry, so this weakens the check. The error happens in the integration between TasksRegistry and TaskManager. Nevertheless, this change adds unit tests to specify the intended behavior of the two modules more closely. Reviewers: Bruno Cadonna --- .../processor/internals/TaskManager.java | 12 ++---- .../streams/processor/internals/Tasks.java | 5 ++- .../processor/internals/TaskManagerTest.java | 40 ++++++++++++++----- .../processor/internals/TasksTest.java | 17 ++++++++ 4 files changed, 54 insertions(+), 20 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index 42da481a8c4..de91e801de7 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -491,7 +491,7 @@ public class TaskManager { } final TaskId taskId = task.id(); if (activeTasksToCreate.containsKey(taskId)) { - handleReAssignedActiveTask(task, activeTasksToCreate.get(taskId)); + handleReassignedActiveTask(task, activeTasksToCreate.get(taskId)); activeTasksToCreate.remove(taskId); } else if (standbyTasksToCreate.containsKey(taskId)) { tasksToRecycle.put(task, standbyTasksToCreate.get(taskId)); @@ -502,22 +502,18 @@ public class TaskManager { } } - private void
handleReAssignedActiveTask(final Task task, + private void handleReassignedActiveTask(final Task task, final Set inputPartitions) { if (tasks.updateActiveTaskInputPartitions(task, inputPartitions)) { task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id())); } if (task.state() == State.SUSPENDED) { + tasks.removeTask(task); task.resume(); - moveTaskFromTasksRegistryToStateUpdater(task); + stateUpdater.add(task); } } - private void moveTaskFromTasksRegistryToStateUpdater(final Task task) { - tasks.removeTask(task); - stateUpdater.add(task); - } - private void handleTasksInStateUpdater(final Map> activeTasksToCreate, final Map> standbyTasksToCreate) { for (final Task task : stateUpdater.getTasks()) { diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java index 894c2587578..a28fb22766d 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java @@ -20,6 +20,7 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.streams.processor.TaskId; import org.apache.kafka.streams.processor.internals.PendingUpdateAction.Action; +import org.apache.kafka.streams.processor.internals.Task.State; import org.slf4j.Logger; import java.util.Collection; @@ -240,8 +241,8 @@ class Tasks implements TasksRegistry { public synchronized void removeTask(final Task taskToRemove) { final TaskId taskId = taskToRemove.id(); - if (taskToRemove.state() != Task.State.CLOSED) { - throw new IllegalStateException("Attempted to remove a task that is not closed: " + taskId); + if (taskToRemove.state() != Task.State.CLOSED && taskToRemove.state() != State.SUSPENDED) { + throw new IllegalStateException("Attempted to remove a task that is not closed or suspended: " + taskId); } if 
(taskToRemove.isActive()) { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java index 85ab9c68688..add26097a65 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -418,7 +418,7 @@ public class TaskManagerTest { } @Test - public void shouldKeepReAssignedActiveTaskInStateUpdater() { + public void shouldKeepReassignedActiveTaskInStateUpdater() { final StreamTask reassignedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) .inState(State.RESTORING) .withInputPartitions(taskId03Partitions).build(); @@ -436,21 +436,41 @@ public class TaskManagerTest { } @Test - public void shouldRemoveReAssignedRevokedActiveTaskInStateUpdaterFromPendingTaskToSuspend() { - final StreamTask reAssignedRevokedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) + public void shouldMoveReassignedSuspendedActiveTaskToStateUpdater() { + final StreamTask reassignedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) + .inState(State.SUSPENDED) + .withInputPartitions(taskId03Partitions).build(); + final TasksRegistry tasks = Mockito.mock(TasksRegistry.class); + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); + when(tasks.allTasks()).thenReturn(mkSet(reassignedActiveTask)); + + taskManager.handleAssignment( + mkMap(mkEntry(reassignedActiveTask.id(), reassignedActiveTask.inputPartitions())), + Collections.emptyMap() + ); + + Mockito.verify(tasks).removeTask(reassignedActiveTask); + Mockito.verify(stateUpdater).add(reassignedActiveTask); + Mockito.verify(activeTaskCreator).createTasks(consumer, Collections.emptyMap()); + Mockito.verify(standbyTaskCreator).createTasks(Collections.emptyMap()); + } + + @Test + public void 
shouldRemoveReassignedRevokedActiveTaskInStateUpdaterFromPendingTaskToSuspend() { + final StreamTask reassignedRevokedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions) .inState(State.RESTORING) .withInputPartitions(taskId03Partitions).build(); final TasksRegistry tasks = Mockito.mock(TasksRegistry.class); final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); - when(stateUpdater.getTasks()).thenReturn(mkSet(reAssignedRevokedActiveTask)); + when(stateUpdater.getTasks()).thenReturn(mkSet(reassignedRevokedActiveTask)); taskManager.handleAssignment( - mkMap(mkEntry(reAssignedRevokedActiveTask.id(), reAssignedRevokedActiveTask.inputPartitions())), + mkMap(mkEntry(reassignedRevokedActiveTask.id(), reassignedRevokedActiveTask.inputPartitions())), Collections.emptyMap() ); Mockito.verify(activeTaskCreator).createTasks(consumer, Collections.emptyMap()); - Mockito.verify(tasks).removePendingActiveTaskToSuspend(reAssignedRevokedActiveTask.id()); + Mockito.verify(tasks).removePendingActiveTaskToSuspend(reassignedRevokedActiveTask.id()); Mockito.verify(standbyTaskCreator).createTasks(Collections.emptyMap()); } @@ -477,17 +497,17 @@ public class TaskManagerTest { } @Test - public void shouldKeepReAssignedStandbyTaskInStateUpdater() { - final StandbyTask reAssignedStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions) + public void shouldKeepReassignedStandbyTaskInStateUpdater() { + final StandbyTask reassignedStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions) .inState(State.RUNNING) .withInputPartitions(taskId02Partitions).build(); final TasksRegistry tasks = Mockito.mock(TasksRegistry.class); final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true); - when(stateUpdater.getTasks()).thenReturn(mkSet(reAssignedStandbyTask)); + when(stateUpdater.getTasks()).thenReturn(mkSet(reassignedStandbyTask)); taskManager.handleAssignment( Collections.emptyMap(), - 
mkMap(mkEntry(reAssignedStandbyTask.id(), reAssignedStandbyTask.inputPartitions())) + mkMap(mkEntry(reassignedStandbyTask.id(), reassignedStandbyTask.inputPartitions())) ); Mockito.verify(activeTaskCreator).createTasks(consumer, Collections.emptyMap()); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java index 23aee1937c0..c65756d41f4 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java @@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.State; import org.junit.jupiter.api.Test; import java.util.Collections; @@ -31,10 +32,12 @@ import static org.apache.kafka.common.utils.Utils.mkSet; import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.standbyTask; import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask; import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statelessTask; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class TasksTest { @@ -52,6 +55,20 @@ public class TasksTest { private final Tasks tasks = new Tasks(new LogContext()); + @Test + public void shouldCheckStateWhenRemoveTask() { + final StreamTask closedTask = statefulTask(TASK_0_0, 
mkSet(TOPIC_PARTITION_A_0)).inState(State.CLOSED).build(); + final StandbyTask suspendedTask = standbyTask(TASK_0_1, mkSet(TOPIC_PARTITION_A_1)).inState(State.SUSPENDED).build(); + final StreamTask runningTask = statelessTask(TASK_1_0).inState(State.RUNNING).build(); + + tasks.addActiveTasks(mkSet(closedTask, runningTask)); + tasks.addStandbyTasks(Collections.singletonList(suspendedTask)); + + assertDoesNotThrow(() -> tasks.removeTask(closedTask)); + assertDoesNotThrow(() -> tasks.removeTask(suspendedTask)); + assertThrows(IllegalStateException.class, () -> tasks.removeTask(runningTask)); + } + @Test public void shouldKeepAddedTasks() { final StreamTask statefulTask = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).build();