
KAFKA-6145: KIP-441: Improve assignment balance (#8588)

Validate that the assignment is always balanced with respect to:
* active assignment balance
* stateful assignment balance
* task-parallel balance

Reviewers: Bruno Cadonna <bruno@confluent.io>, A. Sophie Blee-Goldman <sophie@confluent.io>
John Roesler committed d62f6ebdfe (via GitHub)
  1. clients/src/main/java/org/apache/kafka/common/utils/Utils.java (17 changes)
  2. clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java (84 changes)
  3. streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java (2 changes)
  4. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java (10 changes)
  5. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalancedAssignor.java (30 changes)
  6. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java (245 changes)
  7. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java (79 changes)
  8. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultBalancedAssignor.java (86 changes)
  9. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java (4 changes)
  10. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java (230 changes)
  11. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RankedClient.java (143 changes)
  12. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java (4 changes)
  13. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java (2 changes)
  14. streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java (139 changes)
  15. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java (330 changes)
  16. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java (100 changes)
  17. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySetTest.java (111 changes)
  18. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/DefaultBalancedAssignorTest.java (295 changes)
  19. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java (309 changes)
  20. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RankedClientTest.java (196 changes)
  21. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java (164 changes)
  22. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java (304 changes)
  23. streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueueTest.java (126 changes)

17
clients/src/main/java/org/apache/kafka/common/utils/Utils.java

@@ -1155,4 +1155,21 @@ public final class Utils {
}
return result;
}
@SafeVarargs
public static <E> Set<E> intersection(final Supplier<Set<E>> constructor, final Set<E> first, final Set<E>... set) {
final Set<E> result = constructor.get();
result.addAll(first);
for (final Set<E> s : set) {
result.retainAll(s);
}
return result;
}
public static <E> Set<E> diff(final Supplier<Set<E>> constructor, final Set<E> left, final Set<E> right) {
final Set<E> result = constructor.get();
result.addAll(left);
result.removeAll(right);
return result;
}
}
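
For reference, the new helpers follow the same pattern as the existing Utils.union: the caller supplies the concrete Set constructor, so the result type stays under the caller's control. A minimal usage sketch (the class name SetOpsExample is illustrative, not part of the change):

import static org.apache.kafka.common.utils.Utils.diff;
import static org.apache.kafka.common.utils.Utils.intersection;
import static org.apache.kafka.common.utils.Utils.mkSet;
import java.util.Set;
import java.util.TreeSet;

public class SetOpsExample {
    public static void main(final String[] args) {
        final Set<String> caughtUp = mkSet("a", "b", "c");
        final Set<String> assigned = mkSet("b", "c", "d");
        // intersection keeps the elements present in every argument set
        System.out.println(intersection(TreeSet::new, caughtUp, assigned)); // [b, c]
        // diff keeps the elements of the first set that are absent from the second
        System.out.println(diff(TreeSet::new, caughtUp, assigned)); // [a]
    }
}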

84
clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java

@@ -42,10 +42,13 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static java.util.Arrays.asList;
import static java.util.Collections.emptySet;
import static org.apache.kafka.common.utils.Utils.diff;
import static org.apache.kafka.common.utils.Utils.formatAddress;
import static org.apache.kafka.common.utils.Utils.formatBytes;
import static org.apache.kafka.common.utils.Utils.getHost;
import static org.apache.kafka.common.utils.Utils.getPort;
import static org.apache.kafka.common.utils.Utils.intersection;
import static org.apache.kafka.common.utils.Utils.mkSet;
import static org.apache.kafka.common.utils.Utils.murmur2;
import static org.apache.kafka.common.utils.Utils.union;
@@ -597,4 +600,85 @@ public class UtilsTest {
assertThat(union, is(mkSet("a", "b", "c", "d", "e")));
assertThat(union.getClass(), equalTo(TreeSet.class));
}
@Test
public void testUnionOfOne() {
final Set<String> oneSet = mkSet("a", "b", "c");
final Set<String> union = union(TreeSet::new, oneSet);
assertThat(union, is(mkSet("a", "b", "c")));
assertThat(union.getClass(), equalTo(TreeSet.class));
}
@Test
public void testUnionOfMany() {
final Set<String> oneSet = mkSet("a", "b", "c");
final Set<String> twoSet = mkSet("c", "d", "e");
final Set<String> threeSet = mkSet("b", "c", "d");
final Set<String> fourSet = mkSet("x", "y", "z");
final Set<String> union = union(TreeSet::new, oneSet, twoSet, threeSet, fourSet);
assertThat(union, is(mkSet("a", "b", "c", "d", "e", "x", "y", "z")));
assertThat(union.getClass(), equalTo(TreeSet.class));
}
@Test
public void testUnionOfNone() {
final Set<String> union = union(TreeSet::new);
assertThat(union, is(emptySet()));
assertThat(union.getClass(), equalTo(TreeSet.class));
}
@Test
public void testIntersection() {
final Set<String> oneSet = mkSet("a", "b", "c");
final Set<String> anotherSet = mkSet("c", "d", "e");
final Set<String> intersection = intersection(TreeSet::new, oneSet, anotherSet);
assertThat(intersection, is(mkSet("c")));
assertThat(intersection.getClass(), equalTo(TreeSet.class));
}
@Test
public void testIntersectionOfOne() {
final Set<String> oneSet = mkSet("a", "b", "c");
final Set<String> intersection = intersection(TreeSet::new, oneSet);
assertThat(intersection, is(mkSet("a", "b", "c")));
assertThat(intersection.getClass(), equalTo(TreeSet.class));
}
@Test
public void testIntersectionOfMany() {
final Set<String> oneSet = mkSet("a", "b", "c");
final Set<String> twoSet = mkSet("c", "d", "e");
final Set<String> threeSet = mkSet("b", "c", "d");
final Set<String> union = intersection(TreeSet::new, oneSet, twoSet, threeSet);
assertThat(union, is(mkSet("c")));
assertThat(union.getClass(), equalTo(TreeSet.class));
}
@Test
public void testDisjointIntersectionOfMany() {
final Set<String> oneSet = mkSet("a", "b", "c");
final Set<String> twoSet = mkSet("c", "d", "e");
final Set<String> threeSet = mkSet("b", "c", "d");
final Set<String> fourSet = mkSet("x", "y", "z");
final Set<String> union = intersection(TreeSet::new, oneSet, twoSet, threeSet, fourSet);
assertThat(union, is(emptySet()));
assertThat(union.getClass(), equalTo(TreeSet.class));
}
@Test
public void testDiff() {
final Set<String> oneSet = mkSet("a", "b", "c");
final Set<String> anotherSet = mkSet("c", "d", "e");
final Set<String> diff = diff(TreeSet::new, oneSet, anotherSet);
assertThat(diff, is(mkSet("a", "b")));
assertThat(diff.getClass(), equalTo(TreeSet.class));
}
}

2
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java

@@ -1073,7 +1073,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
// until it has been revoked and can safely be reassigned according to the COOPERATIVE protocol
if (newPartitionForConsumer && allOwnedPartitions.contains(partition)) {
log.info("Removing task {} from assignment until it is safely revoked in followup rebalance", taskId);
clientState.removeFromAssignment(taskId);
clientState.unassignActive(taskId);
// Clear the assigned partitions list for this task if any partition can not safely be assigned,
// so as not to encode a partial task
assignedPartitionsForTask.clear();

10
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java

@@ -364,5 +364,15 @@ public final class AssignorConfiguration {
this.numStandbyReplicas = numStandbyReplicas;
this.probingRebalanceIntervalMs = probingRebalanceIntervalMs;
}
@Override
public String toString() {
return "AssignmentConfigs{" +
"\n acceptableRecoveryLag=" + acceptableRecoveryLag +
"\n maxWarmupReplicas=" + maxWarmupReplicas +
"\n numStandbyReplicas=" + numStandbyReplicas +
"\n probingRebalanceIntervalMs=" + probingRebalanceIntervalMs +
"\n}";
}
}
}

30
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalancedAssignor.java

@@ -1,30 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import java.util.UUID;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import org.apache.kafka.streams.processor.TaskId;
public interface BalancedAssignor {
Map<UUID, List<TaskId>> assign(final SortedSet<UUID> clients,
final SortedSet<TaskId> tasks,
final Map<UUID, Integer> clientsToNumberOfStreamThreads);
}

245
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java

@@ -23,27 +23,30 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.HashMap;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;
import static java.util.Collections.unmodifiableSet;
import static java.util.Comparator.comparing;
import static org.apache.kafka.common.utils.Utils.union;
import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
public class ClientState {
private static final Logger LOG = LoggerFactory.getLogger(ClientState.class);
public static final Comparator<TopicPartition> TOPIC_PARTITION_COMPARATOR = comparing(TopicPartition::topic).thenComparing(TopicPartition::partition);
private final Set<TaskId> activeTasks;
private final Set<TaskId> standbyTasks;
private final Set<TaskId> assignedTasks;
private final Set<TaskId> prevActiveTasks;
private final Set<TaskId> prevStandbyTasks;
private final Set<TaskId> prevAssignedTasks;
private final Map<TopicPartition, String> ownedPartitions;
private final Map<TaskId, Long> taskOffsetSums; // contains only stateful tasks we previously owned
@@ -56,34 +59,28 @@ public class ClientState {
}
ClientState(final int capacity) {
this(new HashSet<>(),
new HashSet<>(),
new HashSet<>(),
new HashSet<>(),
new HashSet<>(),
new HashSet<>(),
new HashMap<>(),
new HashMap<>(),
new HashMap<>(),
capacity);
activeTasks = new TreeSet<>();
standbyTasks = new TreeSet<>();
prevActiveTasks = new TreeSet<>();
prevStandbyTasks = new TreeSet<>();
ownedPartitions = new TreeMap<>(TOPIC_PARTITION_COMPARATOR);
taskOffsetSums = new TreeMap<>();
taskLagTotals = new TreeMap<>();
this.capacity = capacity;
}
private ClientState(final Set<TaskId> activeTasks,
final Set<TaskId> standbyTasks,
final Set<TaskId> assignedTasks,
final Set<TaskId> prevActiveTasks,
final Set<TaskId> prevStandbyTasks,
final Set<TaskId> prevAssignedTasks,
final Map<TopicPartition, String> ownedPartitions,
final SortedMap<TopicPartition, String> ownedPartitions,
final Map<TaskId, Long> taskOffsetSums,
final Map<TaskId, Long> taskLagTotals,
final int capacity) {
this.activeTasks = activeTasks;
this.standbyTasks = standbyTasks;
this.assignedTasks = assignedTasks;
this.prevActiveTasks = prevActiveTasks;
this.prevStandbyTasks = prevStandbyTasks;
this.prevAssignedTasks = prevAssignedTasks;
this.ownedPartitions = ownedPartitions;
this.taskOffsetSums = taskOffsetSums;
this.taskLagTotals = taskLagTotals;
@@ -94,97 +91,149 @@ public class ClientState {
final Set<TaskId> previousStandbyTasks,
final Map<TaskId, Long> taskLagTotals,
final int capacity) {
activeTasks = new HashSet<>();
standbyTasks = new HashSet<>();
assignedTasks = new HashSet<>();
prevActiveTasks = unmodifiableSet(new HashSet<>(previousActiveTasks));
prevStandbyTasks = unmodifiableSet(new HashSet<>(previousStandbyTasks));
prevAssignedTasks = unmodifiableSet(union(HashSet::new, previousActiveTasks, previousStandbyTasks));
ownedPartitions = emptyMap();
activeTasks = new TreeSet<>();
standbyTasks = new TreeSet<>();
prevActiveTasks = unmodifiableSet(new TreeSet<>(previousActiveTasks));
prevStandbyTasks = unmodifiableSet(new TreeSet<>(previousStandbyTasks));
ownedPartitions = new TreeMap<>(TOPIC_PARTITION_COMPARATOR);
taskOffsetSums = emptyMap();
this.taskLagTotals = unmodifiableMap(taskLagTotals);
this.capacity = capacity;
}
public ClientState copy() {
final TreeMap<TopicPartition, String> newOwnedPartitions = new TreeMap<>(TOPIC_PARTITION_COMPARATOR);
newOwnedPartitions.putAll(ownedPartitions);
return new ClientState(
new HashSet<>(activeTasks),
new HashSet<>(standbyTasks),
new HashSet<>(assignedTasks),
new HashSet<>(prevActiveTasks),
new HashSet<>(prevStandbyTasks),
new HashSet<>(prevAssignedTasks),
new HashMap<>(ownedPartitions),
new HashMap<>(taskOffsetSums),
new HashMap<>(taskLagTotals),
new TreeSet<>(activeTasks),
new TreeSet<>(standbyTasks),
new TreeSet<>(prevActiveTasks),
new TreeSet<>(prevStandbyTasks),
newOwnedPartitions,
new TreeMap<>(taskOffsetSums),
new TreeMap<>(taskLagTotals),
capacity);
}
void assignActive(final TaskId task) {
activeTasks.add(task);
assignedTasks.add(task);
int capacity() {
return capacity;
}
void assignStandby(final TaskId task) {
standbyTasks.add(task);
assignedTasks.add(task);
public void incrementCapacity() {
capacity++;
}
boolean reachedCapacity() {
return assignedTaskCount() >= capacity;
}
public Set<TaskId> activeTasks() {
return unmodifiableSet(activeTasks);
}
public int activeTaskCount() {
return activeTasks.size();
}
double activeTaskLoad() {
return ((double) activeTaskCount()) / capacity;
}
public void assignActiveTasks(final Collection<TaskId> tasks) {
activeTasks.addAll(tasks);
assignedTasks.addAll(tasks);
}
void assignStandbyTasks(final Collection<TaskId> tasks) {
standbyTasks.addAll(tasks);
assignedTasks.addAll(tasks);
void assignActive(final TaskId task) {
assertNotAssigned(task);
activeTasks.add(task);
}
public Set<TaskId> activeTasks() {
return activeTasks;
public void unassignActive(final TaskId task) {
if (!activeTasks.contains(task)) {
throw new IllegalArgumentException("Tried to unassign active task " + task + ", but it is not currently assigned: " + this);
}
activeTasks.remove(task);
}
public Set<TaskId> standbyTasks() {
return standbyTasks;
return unmodifiableSet(standbyTasks);
}
Set<TaskId> prevActiveTasks() {
return prevActiveTasks;
boolean hasStandbyTask(final TaskId taskId) {
return standbyTasks.contains(taskId);
}
Set<TaskId> prevStandbyTasks() {
return prevStandbyTasks;
int standbyTaskCount() {
return standbyTasks.size();
}
public Map<TopicPartition, String> ownedPartitions() {
return ownedPartitions;
void assignStandby(final TaskId task) {
assertNotAssigned(task);
standbyTasks.add(task);
}
void unassignStandby(final TaskId task) {
if (!standbyTasks.contains(task)) {
throw new IllegalArgumentException("Tried to unassign standby task " + task + ", but it is not currently assigned: " + this);
}
standbyTasks.remove(task);
}
Set<TaskId> assignedTasks() {
// Since we're copying it, it's not strictly necessary to make it unmodifiable also.
// I'm just trying to prevent subtle bugs if we write code that thinks it can update
// the assignment by updating the returned set.
return unmodifiableSet(
union(
() -> new HashSet<>(activeTasks.size() + standbyTasks.size()),
activeTasks,
standbyTasks
)
);
}
@SuppressWarnings("WeakerAccess")
public int assignedTaskCount() {
return assignedTasks.size();
return activeTaskCount() + standbyTaskCount();
}
public void incrementCapacity() {
capacity++;
double assignedTaskLoad() {
return ((double) assignedTaskCount()) / capacity;
}
public int activeTaskCount() {
return activeTasks.size();
boolean hasAssignedTask(final TaskId taskId) {
return activeTasks.contains(taskId) || standbyTasks.contains(taskId);
}
int standbyTaskCount() {
return standbyTasks.size();
Set<TaskId> prevActiveTasks() {
return unmodifiableSet(prevActiveTasks);
}
private void addPreviousActiveTask(final TaskId task) {
prevActiveTasks.add(task);
}
void addPreviousActiveTasks(final Set<TaskId> prevTasks) {
prevActiveTasks.addAll(prevTasks);
prevAssignedTasks.addAll(prevTasks);
}
Set<TaskId> prevStandbyTasks() {
return unmodifiableSet(prevStandbyTasks);
}
private void addPreviousStandbyTask(final TaskId task) {
prevStandbyTasks.add(task);
}
void addPreviousStandbyTasks(final Set<TaskId> standbyTasks) {
prevStandbyTasks.addAll(standbyTasks);
prevAssignedTasks.addAll(standbyTasks);
}
Set<TaskId> previousAssignedTasks() {
return union(() -> new HashSet<>(prevActiveTasks.size() + prevStandbyTasks.size()), prevActiveTasks, prevStandbyTasks);
}
public Map<TopicPartition, String> ownedPartitions() {
return unmodifiableMap(ownedPartitions);
}
public void addOwnedPartitions(final Collection<TopicPartition> ownedPartitions, final String consumer) {
@@ -252,27 +301,6 @@ public class ClientState {
}
}
public void removeFromAssignment(final TaskId task) {
activeTasks.remove(task);
assignedTasks.remove(task);
}
boolean reachedCapacity() {
return assignedTasks.size() >= capacity;
}
int capacity() {
return capacity;
}
double activeTaskLoad() {
return ((double) activeTaskCount()) / capacity;
}
double taskLoad() {
return ((double) assignedTaskCount()) / capacity;
}
boolean hasUnfulfilledQuota(final int tasksPerThread) {
return activeTasks.size() < capacity * tasksPerThread;
}
@@ -298,8 +326,17 @@ }
}
}
boolean hasAssignedTask(final TaskId taskId) {
return assignedTasks.contains(taskId);
@Override
public String toString() {
return "[activeTasks: (" + activeTasks +
") standbyTasks: (" + standbyTasks +
") prevActiveTasks: (" + prevActiveTasks +
") prevStandbyTasks: (" + prevStandbyTasks +
") prevOwnedPartitionsByConsumerId: (" + ownedPartitions.keySet() +
") changelogOffsetTotalsByTask: (" + taskOffsetSums.entrySet() +
") capacity: " + capacity +
" assigned: " + assignedTaskCount() +
"]";
}
private void initializePrevActiveTasksFromOwnedPartitions(final Map<TopicPartition, TaskId> taskForPartitionMap) {
@@ -336,37 +373,9 @@ }
}
}
private void addPreviousActiveTask(final TaskId task) {
prevActiveTasks.add(task);
prevAssignedTasks.add(task);
}
private void addPreviousStandbyTask(final TaskId task) {
prevStandbyTasks.add(task);
prevAssignedTasks.add(task);
}
@Override
public String toString() {
return "[activeTasks: (" + activeTasks +
") standbyTasks: (" + standbyTasks +
") assignedTasks: (" + assignedTasks +
") prevActiveTasks: (" + prevActiveTasks +
") prevStandbyTasks: (" + prevStandbyTasks +
") prevAssignedTasks: (" + prevAssignedTasks +
") prevOwnedPartitionsByConsumerId: (" + ownedPartitions.keySet() +
") changelogOffsetTotalsByTask: (" + taskOffsetSums.entrySet() +
") capacity: " + capacity +
"]";
}
// Visible for testing
Set<TaskId> assignedTasks() {
return assignedTasks;
}
Set<TaskId> previousAssignedTasks() {
return prevAssignedTasks;
private void assertNotAssigned(final TaskId task) {
if (standbyTasks.contains(task) || activeTasks.contains(task)) {
throw new IllegalArgumentException("Tried to assign task " + task + ", but it is already assigned: " + this);
}
}
}
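
A note on the new invariant: assignActive/assignStandby now assert that the task is not already assigned to this client in either role, and unassignActive/unassignStandby assert the reverse. A small sketch, assuming code living in the same internals.assignment package (the constructor and most mutators are package-private):

final ClientState state = new ClientState(2);
final TaskId task = new TaskId(0, 0);
state.assignActive(task);
try {
    state.assignStandby(task); // a task may be active or standby on a client, never both
} catch (final IllegalArgumentException expected) {
    // "Tried to assign task 0_0, but it is already assigned: ..."
}
state.unassignActive(task);
state.assignStandby(task); // legal once the active assignment is released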

79
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueue.java → streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java

@@ -16,77 +16,58 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;
import org.apache.kafka.streams.processor.TaskId;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.UUID;
import java.util.function.BiFunction;
import org.apache.kafka.streams.processor.TaskId;
import java.util.function.Function;
/**
* Wraps a priority queue of clients and returns the next valid candidate(s) based on the current task assignment
*/
class ValidClientsByTaskLoadQueue {
class ConstrainedPrioritySet {
private final PriorityQueue<UUID> clientsByTaskLoad;
private final BiFunction<UUID, TaskId, Boolean> validClientCriteria;
private final BiFunction<UUID, TaskId, Boolean> constraint;
private final Set<UUID> uniqueClients = new HashSet<>();
ValidClientsByTaskLoadQueue(final Map<UUID, ClientState> clientStates,
final BiFunction<UUID, TaskId, Boolean> validClientCriteria) {
this.validClientCriteria = validClientCriteria;
clientsByTaskLoad = new PriorityQueue<>(
(client, other) -> {
final double clientTaskLoad = clientStates.get(client).taskLoad();
final double otherTaskLoad = clientStates.get(other).taskLoad();
if (clientTaskLoad < otherTaskLoad) {
return -1;
} else if (clientTaskLoad > otherTaskLoad) {
return 1;
} else {
return client.compareTo(other);
}
});
ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> constraint,
final Function<UUID, Double> weight) {
this.constraint = constraint;
clientsByTaskLoad = new PriorityQueue<>(Comparator.comparing(weight).thenComparing(clientId -> clientId));
}
/**
* @return the next least loaded client that satisfies the given criteria, or null if none do
*/
UUID poll(final TaskId task) {
final List<UUID> validClient = poll(task, 1);
return validClient.isEmpty() ? null : validClient.get(0);
}
/**
* @return the next N <= {@code numClientsPerTask} clients in the underlying priority queue that are valid candidates for the given task
*/
List<UUID> poll(final TaskId task, final int numClients) {
final List<UUID> nextLeastLoadedValidClients = new LinkedList<>();
UUID poll(final TaskId task, final Function<UUID, Boolean> extraConstraint) {
final Set<UUID> invalidPolledClients = new HashSet<>();
while (nextLeastLoadedValidClients.size() < numClients) {
UUID candidateClient;
while (true) {
candidateClient = pollNextClient();
if (candidateClient == null) {
offerAll(invalidPolledClients);
return nextLeastLoadedValidClients;
}
if (validClientCriteria.apply(candidateClient, task)) {
nextLeastLoadedValidClients.add(candidateClient);
break;
} else {
invalidPolledClients.add(candidateClient);
}
while (!clientsByTaskLoad.isEmpty()) {
final UUID candidateClient = pollNextClient();
if (constraint.apply(candidateClient, task) && extraConstraint.apply(candidateClient)) {
// then we found the lightest, valid client
offerAll(invalidPolledClients);
return candidateClient;
} else {
// remember this client and try again later
invalidPolledClients.add(candidateClient);
}
}
// we tried all the clients, and none met the constraint (or there are no clients)
offerAll(invalidPolledClients);
return nextLeastLoadedValidClients;
return null;
}
/**
* @return the next least loaded client that satisfies the given criteria, or null if none do
*/
UUID poll(final TaskId task) {
return poll(task, client -> true);
}
void offerAll(final Collection<UUID> clients) {
@@ -105,7 +86,7 @@ class ValidClientsByTaskLoadQueue {
}
private UUID pollNextClient() {
final UUID client = clientsByTaskLoad.poll();
final UUID client = clientsByTaskLoad.remove();
uniqueClients.remove(client);
return client;
}
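
In other words, poll returns the lightest client that passes both the construction-time constraint and the per-call extra constraint, re-offering any rejected candidates for later polls. A sketch, assuming code in the same internals.assignment package (the class is package-private; the client UUIDs and weights are made up):

import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.apache.kafka.streams.processor.TaskId;

static UUID demoPoll() {
    final UUID lightClient = new UUID(0, 1);
    final UUID heavyClient = new UUID(0, 2);
    final Map<UUID, Double> load = new HashMap<>();
    load.put(lightClient, 0.25);
    load.put(heavyClient, 0.75);
    // hard constraint: pretend lightClient already owns every task
    final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(
        (client, task) -> !client.equals(lightClient),
        load::get
    );
    queue.offerAll(load.keySet());
    // lightClient has the lower weight but fails the constraint, so the
    // heavier-but-valid client is chosen; the rejected candidate is re-queued
    return queue.poll(new TaskId(0, 0)); // == heavyClient
}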

86
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultBalancedAssignor.java

@@ -1,86 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import java.util.UUID;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import org.apache.kafka.streams.processor.TaskId;
public class DefaultBalancedAssignor implements BalancedAssignor {
@Override
public Map<UUID, List<TaskId>> assign(final SortedSet<UUID> clients,
final SortedSet<TaskId> tasks,
final Map<UUID, Integer> clientsToNumberOfStreamThreads) {
final Map<UUID, List<TaskId>> assignment = new HashMap<>();
clients.forEach(client -> assignment.put(client, new ArrayList<>()));
distributeTasksEvenlyOverClients(assignment, clients, tasks);
balanceTasksOverStreamThreads(assignment, clients, clientsToNumberOfStreamThreads);
return assignment;
}
private void distributeTasksEvenlyOverClients(final Map<UUID, List<TaskId>> assignment,
final SortedSet<UUID> clients,
final SortedSet<TaskId> tasks) {
final LinkedList<TaskId> tasksToAssign = new LinkedList<>(tasks);
while (!tasksToAssign.isEmpty()) {
for (final UUID client : clients) {
final TaskId task = tasksToAssign.poll();
if (task == null) {
break;
}
assignment.get(client).add(task);
}
}
}
private void balanceTasksOverStreamThreads(final Map<UUID, List<TaskId>> assignment,
final SortedSet<UUID> clients,
final Map<UUID, Integer> clientsToNumberOfStreamThreads) {
boolean stop = false;
while (!stop) {
stop = true;
for (final UUID sourceClient : clients) {
final List<TaskId> sourceTasks = assignment.get(sourceClient);
for (final UUID destinationClient : clients) {
if (sourceClient.equals(destinationClient)) {
continue;
}
final List<TaskId> destinationTasks = assignment.get(destinationClient);
final int assignedTasksPerStreamThreadAtDestination =
destinationTasks.size() / clientsToNumberOfStreamThreads.get(destinationClient);
final int assignedTasksPerStreamThreadAtSource =
sourceTasks.size() / clientsToNumberOfStreamThreads.get(sourceClient);
if (assignedTasksPerStreamThreadAtSource - assignedTasksPerStreamThreadAtDestination > 1) {
final Iterator<TaskId> sourceIterator = sourceTasks.iterator();
final TaskId taskToMove = sourceIterator.next();
sourceIterator.remove();
destinationTasks.add(taskToMove);
stop = false;
}
}
}
}
}
}

4
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java

@@ -41,9 +41,9 @@ public class FallbackPriorTaskAssignor implements TaskAssignor {
@Override
public boolean assign(final Map<UUID, ClientState> clients,
final Set<TaskId> allTaskIds,
final Set<TaskId> standbyTaskIds,
final Set<TaskId> statefulTaskIds,
final AssignmentConfigs configs) {
delegate.assign(clients, allTaskIds, standbyTaskIds, configs);
delegate.assign(clients, allTaskIds, statefulTaskIds, configs);
return true;
}
}

230
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java

@@ -17,130 +17,183 @@
package org.apache.kafka.streams.processor.internals.assignment;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.Task;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.buildClientRankingsByTask;
import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.tasksToCaughtUpClients;
import static org.apache.kafka.common.utils.Utils.diff;
import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignTaskMovements;
public class HighAvailabilityTaskAssignor implements TaskAssignor {
private static final Logger log = LoggerFactory.getLogger(HighAvailabilityTaskAssignor.class);
private Map<UUID, ClientState> clientStates;
private Map<UUID, Integer> clientsToNumberOfThreads;
private SortedSet<UUID> sortedClients;
private Set<TaskId> allTasks;
private SortedSet<TaskId> statefulTasks;
private SortedSet<TaskId> statelessTasks;
private AssignmentConfigs configs;
private SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates;
private Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients;
@Override
public boolean assign(final Map<UUID, ClientState> clientStates,
final Set<TaskId> allTasks,
final Set<TaskId> statefulTasks,
public boolean assign(final Map<UUID, ClientState> clients,
final Set<TaskId> allTaskIds,
final Set<TaskId> statefulTaskIds,
final AssignmentConfigs configs) {
this.configs = configs;
this.clientStates = clientStates;
this.allTasks = allTasks;
this.statefulTasks = new TreeSet<>(statefulTasks);
statelessTasks = new TreeSet<>(allTasks);
statelessTasks.removeAll(statefulTasks);
sortedClients = new TreeSet<>();
clientsToNumberOfThreads = new HashMap<>();
clientStates.forEach((client, state) -> {
sortedClients.add(client);
clientsToNumberOfThreads.put(client, state.capacity());
});
final SortedSet<TaskId> statefulTasks = new TreeSet<>(statefulTaskIds);
final TreeMap<UUID, ClientState> clientStates = new TreeMap<>(clients);
statefulTasksToRankedCandidates =
buildClientRankingsByTask(statefulTasks, clientStates, configs.acceptableRecoveryLag);
tasksToCaughtUpClients = tasksToCaughtUpClients(statefulTasksToRankedCandidates);
assignActiveStatefulTasks(clientStates, statefulTasks);
assignStandbyReplicaTasks(
clientStates,
statefulTasks,
configs.numStandbyReplicas
);
final Map<TaskId, Integer> tasksToRemainingStandbys =
statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> configs.numStandbyReplicas));
final boolean probingRebalanceNeeded = assignStatefulActiveTasks(tasksToRemainingStandbys);
assignStandbyReplicaTasks(tasksToRemainingStandbys);
final boolean probingRebalanceNeeded = assignTaskMovements(
tasksToCaughtUpClients(statefulTasks, clientStates, configs.acceptableRecoveryLag),
clientStates,
configs.maxWarmupReplicas
);
assignStatelessActiveTasks();
assignStatelessActiveTasks(clientStates, diff(TreeSet::new, allTaskIds, statefulTasks));
log.info("Decided on assignment: " +
clientStates +
" with " +
(probingRebalanceNeeded ? "" : "no") +
" followup probing rebalance.");
return probingRebalanceNeeded;
}
private boolean assignStatefulActiveTasks(final Map<TaskId, Integer> tasksToRemainingStandbys) {
final Map<UUID, List<TaskId>> statefulActiveTaskAssignment = new DefaultBalancedAssignor().assign(
sortedClients,
statefulTasks,
clientsToNumberOfThreads
);
private static void assignActiveStatefulTasks(final SortedMap<UUID, ClientState> clientStates,
final SortedSet<TaskId> statefulTasks) {
Iterator<ClientState> clientStateIterator = null;
for (final TaskId task : statefulTasks) {
if (clientStateIterator == null || !clientStateIterator.hasNext()) {
clientStateIterator = clientStates.values().iterator();
}
clientStateIterator.next().assignActive(task);
}
return assignTaskMovements(
statefulActiveTaskAssignment,
tasksToCaughtUpClients,
balanceTasksOverThreads(
clientStates,
tasksToRemainingStandbys,
configs.maxWarmupReplicas
ClientState::activeTasks,
ClientState::unassignActive,
ClientState::assignActive
);
}
private void assignStandbyReplicaTasks(final Map<TaskId, Integer> tasksToRemainingStandbys) {
final ValidClientsByTaskLoadQueue standbyTaskClientsByTaskLoad = new ValidClientsByTaskLoadQueue(
clientStates,
(client, task) -> !clientStates.get(client).assignedTasks().contains(task)
private static void assignStandbyReplicaTasks(final TreeMap<UUID, ClientState> clientStates,
final Set<TaskId> statefulTasks,
final int numStandbyReplicas) {
final Map<TaskId, Integer> tasksToRemainingStandbys =
statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> numStandbyReplicas));
final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = new ConstrainedPrioritySet(
(client, task) -> !clientStates.get(client).hasAssignedTask(task),
client -> clientStates.get(client).assignedTaskLoad()
);
standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet());
for (final TaskId task : statefulTasksToRankedCandidates.keySet()) {
final int numRemainingStandbys = tasksToRemainingStandbys.get(task);
final List<UUID> clients = standbyTaskClientsByTaskLoad.poll(task, numRemainingStandbys);
for (final UUID client : clients) {
for (final TaskId task : statefulTasks) {
int numRemainingStandbys = tasksToRemainingStandbys.get(task);
while (numRemainingStandbys > 0) {
final UUID client = standbyTaskClientsByTaskLoad.poll(task);
if (client == null) {
break;
}
clientStates.get(client).assignStandby(task);
numRemainingStandbys--;
standbyTaskClientsByTaskLoad.offer(client);
}
standbyTaskClientsByTaskLoad.offerAll(clients);
final int numStandbysAssigned = clients.size();
if (numStandbysAssigned < numRemainingStandbys) {
if (numRemainingStandbys > 0) {
log.warn("Unable to assign {} of {} standby tasks for task [{}]. " +
"There is not enough available capacity. You should " +
"increase the number of threads and/or application instances " +
"to maintain the requested number of standby replicas.",
numRemainingStandbys - numStandbysAssigned, configs.numStandbyReplicas, task);
numRemainingStandbys, numStandbyReplicas, task);
}
}
}
private void assignStatelessActiveTasks() {
final ValidClientsByTaskLoadQueue statelessActiveTaskClientsByTaskLoad = new ValidClientsByTaskLoadQueue(
balanceTasksOverThreads(
clientStates,
(client, task) -> true
ClientState::standbyTasks,
ClientState::unassignStandby,
ClientState::assignStandby
);
}
private static void balanceTasksOverThreads(final SortedMap<UUID, ClientState> clientStates,
final Function<ClientState, Set<TaskId>> currentAssignmentAccessor,
final BiConsumer<ClientState, TaskId> taskUnassignor,
final BiConsumer<ClientState, TaskId> taskAssignor) {
boolean keepBalancing = true;
while (keepBalancing) {
keepBalancing = false;
for (final Map.Entry<UUID, ClientState> sourceEntry : clientStates.entrySet()) {
final UUID sourceClient = sourceEntry.getKey();
final ClientState sourceClientState = sourceEntry.getValue();
for (final Map.Entry<UUID, ClientState> destinationEntry : clientStates.entrySet()) {
final UUID destinationClient = destinationEntry.getKey();
final ClientState destinationClientState = destinationEntry.getValue();
if (sourceClient.equals(destinationClient)) {
continue;
}
final Set<TaskId> sourceTasks = new TreeSet<>(currentAssignmentAccessor.apply(sourceClientState));
final Iterator<TaskId> sourceIterator = sourceTasks.iterator();
while (shouldMoveATask(sourceClientState, destinationClientState) && sourceIterator.hasNext()) {
final TaskId taskToMove = sourceIterator.next();
final boolean canMove = !destinationClientState.hasAssignedTask(taskToMove);
if (canMove) {
taskUnassignor.accept(sourceClientState, taskToMove);
taskAssignor.accept(destinationClientState, taskToMove);
keepBalancing = true;
}
}
}
}
}
}
private static boolean shouldMoveATask(final ClientState sourceClientState,
final ClientState destinationClientState) {
final double skew = sourceClientState.assignedTaskLoad() - destinationClientState.assignedTaskLoad();
if (skew <= 0) {
return false;
}
final double proposedAssignedTasksPerStreamThreadAtDestination =
(destinationClientState.assignedTaskCount() + 1.0) / destinationClientState.capacity();
final double proposedAssignedTasksPerStreamThreadAtSource =
(sourceClientState.assignedTaskCount() - 1.0) / sourceClientState.capacity();
final double proposedSkew = proposedAssignedTasksPerStreamThreadAtSource - proposedAssignedTasksPerStreamThreadAtDestination;
if (proposedSkew < 0) {
// then the move would only create an imbalance in the other direction.
return false;
}
// we should only move a task if doing so would actually improve the skew.
return proposedSkew < skew;
}
private static void assignStatelessActiveTasks(final TreeMap<UUID, ClientState> clientStates,
final Iterable<TaskId> statelessTasks) {
final ConstrainedPrioritySet statelessActiveTaskClientsByTaskLoad = new ConstrainedPrioritySet(
(client, task) -> true,
client -> clientStates.get(client).activeTaskLoad()
);
statelessActiveTaskClientsByTaskLoad.offerAll(clientStates.keySet());
@@ -152,27 +205,22 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
}
}
/**
* Compute the balance factor as the difference in stateful active task count per thread between the most and
* least loaded clients
*/
static int computeBalanceFactor(final Collection<ClientState> clientStates,
final Set<TaskId> statefulTasks) {
int minActiveStatefulTasksPerThreadCount = Integer.MAX_VALUE;
int maxActiveStatefulTasksPerThreadCount = 0;
for (final ClientState state : clientStates) {
final Set<TaskId> activeTasks = new HashSet<>(state.prevActiveTasks());
activeTasks.retainAll(statefulTasks);
final int taskPerThreadCount = activeTasks.size() / state.capacity();
if (taskPerThreadCount < minActiveStatefulTasksPerThreadCount) {
minActiveStatefulTasksPerThreadCount = taskPerThreadCount;
}
if (taskPerThreadCount > maxActiveStatefulTasksPerThreadCount) {
maxActiveStatefulTasksPerThreadCount = taskPerThreadCount;
private static Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients(final Set<TaskId> statefulTasks,
final Map<UUID, ClientState> clientStates,
final long acceptableRecoveryLag) {
final Map<TaskId, SortedSet<UUID>> taskToCaughtUpClients = new HashMap<>();
for (final TaskId task : statefulTasks) {
for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
final UUID client = clientEntry.getKey();
final long taskLag = clientEntry.getValue().lagFor(task);
if (taskLag == Task.LATEST_OFFSET || taskLag <= acceptableRecoveryLag) {
taskToCaughtUpClients.computeIfAbsent(task, ignored -> new TreeSet<>()).add(client);
}
}
}
return maxActiveStatefulTasksPerThreadCount - minActiveStatefulTasksPerThreadCount;
return taskToCaughtUpClients;
}
}
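
The skew test in shouldMoveATask is worth a worked example. A standalone restatement of the same arithmetic with plain ints (a sketch, not the committed method, which operates on ClientState):

static boolean shouldMove(final int sourceTasks, final int sourceCapacity,
                          final int destTasks, final int destCapacity) {
    final double skew = (double) sourceTasks / sourceCapacity - (double) destTasks / destCapacity;
    if (skew <= 0) {
        return false; // source is not more loaded than destination
    }
    final double proposedSkew = (sourceTasks - 1.0) / sourceCapacity - (destTasks + 1.0) / destCapacity;
    if (proposedSkew < 0) {
        return false; // the move would just flip the imbalance the other way
    }
    return proposedSkew < skew; // move only if it strictly improves the balance
}

For example, shouldMove(4, 2, 0, 1) is true (per-thread load goes from 2.0 vs 0.0 to 1.5 vs 1.0), while shouldMove(4, 2, 1, 1) is false: the destination would end up at 2.0 tasks per thread against the source's 1.5, flipping the imbalance.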

143
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RankedClient.java

@@ -1,143 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.Task;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RankedClient implements Comparable<RankedClient> {
private static final Logger log = LoggerFactory.getLogger(RankedClient.class);
private final UUID clientId;
private final long rank;
RankedClient(final UUID clientId, final long rank) {
this.clientId = clientId;
this.rank = rank;
}
UUID clientId() {
return clientId;
}
long rank() {
return rank;
}
@Override
public int compareTo(final RankedClient clientIdAndLag) {
if (rank < clientIdAndLag.rank) {
return -1;
} else if (rank > clientIdAndLag.rank) {
return 1;
} else {
return clientId.compareTo(clientIdAndLag.clientId);
}
}
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final RankedClient that = (RankedClient) o;
return rank == that.rank && Objects.equals(clientId, that.clientId);
}
@Override
public int hashCode() {
return Objects.hash(clientId, rank);
}
/**
* Maps tasks to clients with caught-up states for the task.
*
* @param statefulTasksToRankedClients ranked clients map
* @return map from tasks with caught-up clients to the list of client candidates
*/
static Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients(final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients) {
final Map<TaskId, SortedSet<UUID>> taskToCaughtUpClients = new HashMap<>();
for (final SortedMap.Entry<TaskId, SortedSet<RankedClient>> taskToRankedClients : statefulTasksToRankedClients.entrySet()) {
final SortedSet<RankedClient> rankedClients = taskToRankedClients.getValue();
for (final RankedClient rankedClient : rankedClients) {
if (rankedClient.rank() == Task.LATEST_OFFSET || rankedClient.rank() == 0) {
final TaskId taskId = taskToRankedClients.getKey();
taskToCaughtUpClients.computeIfAbsent(taskId, ignored -> new TreeSet<>()).add(rankedClient.clientId());
} else {
break;
}
}
}
return taskToCaughtUpClients;
}
/**
* Rankings are computed as follows, with lower being more caught up:
* Rank -1: active running task
* Rank 0: standby or restoring task whose overall lag is within the acceptableRecoveryLag bounds
* Rank 1: tasks whose lag is unknown, eg because it was not encoded in an older version subscription.
* Since it may have been caught-up, we rank it higher than clients whom we know are not caught-up
* to give it priority without classifying it as caught-up and risking violating high availability
* Rank 1+: all other tasks are ranked according to their actual total lag
* @return Sorted set of all client candidates for each stateful task, ranked by their overall lag. Tasks are
*/
static SortedMap<TaskId, SortedSet<RankedClient>> buildClientRankingsByTask(final Set<TaskId> statefulTasks,
final Map<UUID, ClientState> clientStates,
final long acceptableRecoveryLag) {
final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates = new TreeMap<>();
for (final TaskId task : statefulTasks) {
final SortedSet<RankedClient> rankedClientCandidates = new TreeSet<>();
statefulTasksToRankedCandidates.put(task, rankedClientCandidates);
for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
final UUID clientId = clientEntry.getKey();
final long taskLag = clientEntry.getValue().lagFor(task);
final long clientRank;
if (taskLag == Task.LATEST_OFFSET) {
clientRank = Task.LATEST_OFFSET;
} else if (taskLag == UNKNOWN_OFFSET_SUM) {
clientRank = 1L;
} else if (taskLag <= acceptableRecoveryLag) {
clientRank = 0L;
} else {
clientRank = taskLag;
}
rankedClientCandidates.add(new RankedClient(clientId, clientRank));
}
}
log.trace("Computed statefulTasksToRankedCandidates map as {}", statefulTasksToRankedCandidates);
return statefulTasksToRankedCandidates;
}
}

4
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java

@@ -56,11 +56,11 @@ public class StickyTaskAssignor implements TaskAssignor {
@Override
public boolean assign(final Map<UUID, ClientState> clients,
final Set<TaskId> allTaskIds,
final Set<TaskId> standbyTaskIds,
final Set<TaskId> statefulTaskIds,
final AssignmentConfigs configs) {
this.clients = clients;
this.allTaskIds = allTaskIds;
this.standbyTaskIds = standbyTaskIds;
this.standbyTaskIds = statefulTaskIds;
final int maxPairs = allTaskIds.size() * (allTaskIds.size() - 1) / 2;
taskPairs = new TaskPairs(maxPairs);

2
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java

@@ -28,6 +28,6 @@ public interface TaskAssignor {
*/
boolean assign(Map<UUID, ClientState> clients,
Set<TaskId> allTaskIds,
Set<TaskId> standbyTaskIds,
Set<TaskId> statefulTaskIds,
AssignorConfiguration.AssignmentConfigs configs);
}

139
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java

@@ -16,16 +16,22 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;
import java.util.List;
import org.apache.kafka.streams.processor.TaskId;
import java.util.Comparator;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.kafka.streams.processor.TaskId;
import java.util.function.BiFunction;
class TaskMovement {
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
final class TaskMovement {
private final TaskId task;
private final UUID destination;
private final SortedSet<UUID> caughtUpClients;
@@ -40,6 +46,14 @@ class TaskMovement {
}
}
private TaskId task() {
return task;
}
private int numCaughtUpClients() {
return caughtUpClients.size();
}
/**
* @return true if this client is caught-up for this task, or the task has no caught-up clients
*/
@@ -53,75 +67,94 @@ class TaskMovement {
/**
* @return whether any warmup replicas were assigned
*/
static boolean assignTaskMovements(final Map<UUID, List<TaskId>> statefulActiveTaskAssignment,
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
static boolean assignTaskMovements(final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
final Map<UUID, ClientState> clientStates,
final Map<TaskId, Integer> tasksToRemainingStandbys,
final int maxWarmupReplicas) {
boolean warmupReplicasAssigned = false;
final BiFunction<UUID, TaskId, Boolean> caughtUpPredicate =
(client, task) -> taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients);
final ValidClientsByTaskLoadQueue clientsByTaskLoad = new ValidClientsByTaskLoadQueue(
clientStates,
(client, task) -> taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients)
final ConstrainedPrioritySet caughtUpClientsByTaskLoad = new ConstrainedPrioritySet(
caughtUpPredicate,
client -> clientStates.get(client).assignedTaskLoad()
);
final SortedSet<TaskMovement> taskMovements = new TreeSet<>(
(movement, other) -> {
final int numCaughtUpClients = movement.caughtUpClients.size();
final int otherNumCaughtUpClients = other.caughtUpClients.size();
if (numCaughtUpClients != otherNumCaughtUpClients) {
return Integer.compare(numCaughtUpClients, otherNumCaughtUpClients);
} else {
return movement.task.compareTo(other.task);
}
}
final Queue<TaskMovement> taskMovements = new PriorityQueue<>(
Comparator.comparing(TaskMovement::numCaughtUpClients).thenComparing(TaskMovement::task)
);
for (final Map.Entry<UUID, List<TaskId>> assignmentEntry : statefulActiveTaskAssignment.entrySet()) {
final UUID client = assignmentEntry.getKey();
final ClientState state = clientStates.get(client);
for (final TaskId task : assignmentEntry.getValue()) {
if (taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients)) {
state.assignActive(task);
} else {
final TaskMovement taskMovement = new TaskMovement(task, client, tasksToCaughtUpClients.get(task));
taskMovements.add(taskMovement);
for (final Map.Entry<UUID, ClientState> clientStateEntry : clientStates.entrySet()) {
final UUID client = clientStateEntry.getKey();
final ClientState state = clientStateEntry.getValue();
for (final TaskId task : state.activeTasks()) {
// if the desired client is not caught up, and there is another client that _is_ caught up, then
// we schedule a movement, so we can move the active task to the caught-up client. We'll try to
// assign a warm-up to the desired client so that we can move it later on.
if (!taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients)) {
taskMovements.add(new TaskMovement(task, client, tasksToCaughtUpClients.get(task)));
}
}
clientsByTaskLoad.offer(client);
caughtUpClientsByTaskLoad.offer(client);
}
final boolean movementsNeeded = !taskMovements.isEmpty();
final AtomicInteger remainingWarmupReplicas = new AtomicInteger(maxWarmupReplicas);
for (final TaskMovement movement : taskMovements) {
final UUID sourceClient = clientsByTaskLoad.poll(movement.task);
if (sourceClient == null) {
throw new IllegalStateException("Tried to move task to caught-up client but none exist");
}
final ClientState sourceClientState = clientStates.get(sourceClient);
sourceClientState.assignActive(movement.task);
clientsByTaskLoad.offer(sourceClient);
final UUID standbySourceClient = caughtUpClientsByTaskLoad.poll(
movement.task,
c -> clientStates.get(c).hasStandbyTask(movement.task)
);
if (standbySourceClient == null) {
// there's not a caught-up standby available to take over the task, so we'll schedule a warmup instead
final UUID sourceClient = requireNonNull(
caughtUpClientsByTaskLoad.poll(movement.task),
"Tried to move task to caught-up client but none exist"
);
final ClientState destinationClientState = clientStates.get(movement.destination);
if (shouldAssignWarmupReplica(movement.task, destinationClientState, remainingWarmupReplicas, tasksToRemainingStandbys)) {
destinationClientState.assignStandby(movement.task);
clientsByTaskLoad.offer(movement.destination);
warmupReplicasAssigned = true;
moveActiveAndTryToWarmUp(
remainingWarmupReplicas,
movement.task,
clientStates.get(sourceClient),
clientStates.get(movement.destination)
);
caughtUpClientsByTaskLoad.offerAll(asList(sourceClient, movement.destination));
} else {
// we found a candidate to trade standby/active state with our destination, so we don't need a warmup
swapStandbyAndActive(
movement.task,
clientStates.get(standbySourceClient),
clientStates.get(movement.destination)
);
caughtUpClientsByTaskLoad.offerAll(asList(standbySourceClient, movement.destination));
}
}
return warmupReplicasAssigned;
return movementsNeeded;
}
private static boolean shouldAssignWarmupReplica(final TaskId task,
final ClientState destinationClientState,
final AtomicInteger remainingWarmupReplicas,
final Map<TaskId, Integer> tasksToRemainingStandbys) {
if (destinationClientState.previousAssignedTasks().contains(task) && tasksToRemainingStandbys.get(task) > 0) {
tasksToRemainingStandbys.compute(task, (t, numStandbys) -> numStandbys - 1);
return true;
private static void moveActiveAndTryToWarmUp(final AtomicInteger remainingWarmupReplicas,
final TaskId task,
final ClientState sourceClientState,
final ClientState destinationClientState) {
sourceClientState.assignActive(task);
if (remainingWarmupReplicas.getAndDecrement() > 0) {
destinationClientState.unassignActive(task);
destinationClientState.assignStandby(task);
} else {
return remainingWarmupReplicas.getAndDecrement() > 0;
// we have no more standbys or warmups to hand out, so we have to try and move it
// to the destination in a follow-on rebalance
destinationClientState.unassignActive(task);
}
}
private static void swapStandbyAndActive(final TaskId task,
final ClientState sourceClientState,
final ClientState destinationClientState) {
sourceClientState.unassignStandby(task);
sourceClientState.assignActive(task);
destinationClientState.unassignActive(task);
destinationClientState.assignStandby(task);
}
}
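
To make the two branches concrete, a toy model (illustrative types only; the committed code mutates ClientState instances via the package-private assign/unassign methods):

import java.util.HashSet;
import java.util.Set;

public class TaskMovementSketch {
    static final class Client {
        final Set<String> active = new HashSet<>();
        final Set<String> standby = new HashSet<>();
    }

    // branch 1: a caught-up client already holds a standby copy, so the two
    // clients trade roles and no warmup replica is spent
    static void swapStandbyAndActive(final String task, final Client caughtUp, final Client desired) {
        caughtUp.standby.remove(task);
        caughtUp.active.add(task);
        desired.active.remove(task);
        desired.standby.add(task);
    }

    public static void main(final String[] args) {
        final Client caughtUp = new Client(); // caught up, holds the standby
        final Client desired = new Client();  // assigned the active task, but lagging
        caughtUp.standby.add("0_0");
        desired.active.add("0_0");
        swapStandbyAndActive("0_0", caughtUp, desired);
        System.out.println(caughtUp.active);  // [0_0] -- keeps processing on caught-up state
        System.out.println(desired.standby);  // [0_0] -- catches up for a later rebalance
    }
}

Branch 2 (moveActiveAndTryToWarmUp) is the fallback when no caught-up standby exists: the active task moves to a caught-up client outright, and the desired client receives a warmup standby only while remainingWarmupReplicas lasts.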

330
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java

@@ -16,18 +16,35 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.TaskId;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.TaskId;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.apache.kafka.common.utils.Utils.entriesToMap;
import static org.apache.kafka.common.utils.Utils.intersection;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.fail;
public class AssignmentTestUtils {
public final class AssignmentTestUtils {
public static final UUID UUID_1 = uuidForInt(1);
public static final UUID UUID_2 = uuidForInt(2);
@@ -57,6 +74,8 @@ public class AssignmentTestUtils {
public static final Map<TaskId, Long> EMPTY_TASK_OFFSET_SUMS = emptyMap();
public static final Map<TopicPartition, Long> EMPTY_CHANGELOG_END_OFFSETS = new HashMap<>();
private AssignmentTestUtils() {}
static Map<UUID, ClientState> getClientStatesMap(final ClientState... states) {
final Map<UUID, ClientState> clientStates = new HashMap<>();
@@ -75,4 +94,303 @@ public class AssignmentTestUtils {
static UUID uuidForInt(final int n) {
return new UUID(0, n);
}
static void assertValidAssignment(final int numStandbyReplicas,
final Set<TaskId> statefulTasks,
final Set<TaskId> statelessTasks,
final Map<UUID, ClientState> assignedStates,
final StringBuilder failureContext) {
assertValidAssignment(
numStandbyReplicas,
0,
statefulTasks,
statelessTasks,
assignedStates,
failureContext
);
}
static void assertValidAssignment(final int numStandbyReplicas,
final int maxWarmupReplicas,
final Set<TaskId> statefulTasks,
final Set<TaskId> statelessTasks,
final Map<UUID, ClientState> assignedStates,
final StringBuilder failureContext) {
final Map<TaskId, Set<UUID>> assignments = new TreeMap<>();
for (final TaskId taskId : statefulTasks) {
assignments.put(taskId, new TreeSet<>());
}
for (final TaskId taskId : statelessTasks) {
assignments.put(taskId, new TreeSet<>());
}
for (final Map.Entry<UUID, ClientState> entry : assignedStates.entrySet()) {
validateAndAddActiveAssignments(statefulTasks, statelessTasks, failureContext, assignments, entry);
validateAndAddStandbyAssignments(statefulTasks, statelessTasks, failureContext, assignments, entry);
}
final AtomicInteger remainingWarmups = new AtomicInteger(maxWarmupReplicas);
final TreeMap<TaskId, Set<UUID>> misassigned =
assignments
.entrySet()
.stream()
.filter(entry -> {
final int expectedActives = 1;
final boolean isStateless = statelessTasks.contains(entry.getKey());
final int expectedStandbys = isStateless ? 0 : numStandbyReplicas;
// We'll never assign even the expected number of standbys if they don't actually fit in the cluster
final int expectedAssignments = Math.min(
assignedStates.size(),
expectedActives + expectedStandbys
);
final int actualAssignments = entry.getValue().size();
if (actualAssignments == expectedAssignments) {
return false; // not misassigned
} else {
if (actualAssignments == expectedAssignments + 1 && remainingWarmups.get() > 0) {
remainingWarmups.getAndDecrement();
return false; // it's a warmup, so it's fine
} else {
return true; // misassigned
}
}
})
.collect(entriesToMap(TreeMap::new));
if (!misassigned.isEmpty()) {
assertThat(
new StringBuilder().append("Found some over- or under-assigned tasks in the final assignment with ")
.append(numStandbyReplicas)
.append(" and max warmups ")
.append(maxWarmupReplicas)
.append(" standby replicas, stateful tasks:")
.append(statefulTasks)
.append(", and stateless tasks:")
.append(statelessTasks)
.append(failureContext)
.toString(),
misassigned,
is(emptyMap()));
}
}
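// Worked example of the accounting above, with hypothetical numbers: given 3 clients,
// numStandbyReplicas = 1, and maxWarmupReplicas = 1, each stateful task expects
// min(3, 1 active + 1 standby) = 2 assignments, and a single task across the whole
// assignment may carry 2 + 1 = 3 copies (consuming the one warmup) before anything
// is reported as misassigned.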
private static void validateAndAddStandbyAssignments(final Set<TaskId> statefulTasks,
final Set<TaskId> statelessTasks,
final StringBuilder failureContext,
final Map<TaskId, Set<UUID>> assignments,
final Map.Entry<UUID, ClientState> entry) {
for (final TaskId standbyTask : entry.getValue().standbyTasks()) {
if (statelessTasks.contains(standbyTask)) {
throw new AssertionError(
new StringBuilder().append("Found a standby task for stateless task ")
.append(standbyTask)
.append(" on client ")
.append(entry)
.append(" stateless tasks:")
.append(statelessTasks)
.append(failureContext)
.toString()
);
} else if (assignments.containsKey(standbyTask)) {
assignments.get(standbyTask).add(entry.getKey());
} else {
throw new AssertionError(
new StringBuilder().append("Found an extra standby task ")
.append(standbyTask)
.append(" on client ")
.append(entry)
.append(" but expected stateful tasks:")
.append(statefulTasks)
.append(failureContext)
.toString()
);
}
}
}
private static void validateAndAddActiveAssignments(final Set<TaskId> statefulTasks,
final Set<TaskId> statelessTasks,
final StringBuilder failureContext,
final Map<TaskId, Set<UUID>> assignments,
final Map.Entry<UUID, ClientState> entry) {
for (final TaskId activeTask : entry.getValue().activeTasks()) {
if (assignments.containsKey(activeTask)) {
assignments.get(activeTask).add(entry.getKey());
} else {
throw new AssertionError(
new StringBuilder().append("Found an extra active task ")
.append(activeTask)
.append(" on client ")
.append(entry)
.append(" but expected stateful tasks:")
.append(statefulTasks)
.append(" and stateless tasks:")
.append(statelessTasks)
.append(failureContext)
.toString()
);
}
}
}
static void assertBalancedStatefulAssignment(final Set<TaskId> allStatefulTasks,
final Map<UUID, ClientState> clientStates,
final StringBuilder failureContext) {
// Double.MIN_VALUE is the smallest positive double, so seed the running max with negative infinity
double maxStateful = Double.NEGATIVE_INFINITY;
double minStateful = Double.MAX_VALUE;
for (final ClientState clientState : clientStates.values()) {
final Set<TaskId> statefulTasks =
intersection(HashSet::new, clientState.assignedTasks(), allStatefulTasks);
final double statefulTaskLoad = 1.0 * statefulTasks.size() / clientState.capacity();
maxStateful = Math.max(maxStateful, statefulTaskLoad);
minStateful = Math.min(minStateful, statefulTaskLoad);
}
final double statefulDiff = maxStateful - minStateful;
if (statefulDiff > 1.0) {
final StringBuilder builder = new StringBuilder()
.append("detected a stateful assignment balance factor violation: ")
.append(statefulDiff)
.append(">")
.append(1.0)
.append(" in: ");
appendClientStates(builder, clientStates);
fail(builder.append(failureContext).toString());
}
}
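// Illustrative numbers for the check above: a client running 3 stateful tasks on one
// thread has a load of 3.0, and a peer running 1 stateful task on one thread has a
// load of 1.0; the difference of 2.0 exceeds the 1.0 bound and fails the assertion,
// whereas loads of 2.0 and 1.0 (a difference of exactly 1.0) would pass.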
static void assertBalancedActiveAssignment(final Map<UUID, ClientState> clientStates,
final StringBuilder failureContext) {
// as above, seed the running max with negative infinity rather than the smallest positive double
double maxActive = Double.NEGATIVE_INFINITY;
double minActive = Double.MAX_VALUE;
for (final ClientState clientState : clientStates.values()) {
final double activeTaskLoad = clientState.activeTaskLoad();
maxActive = Math.max(maxActive, activeTaskLoad);
minActive = Math.min(minActive, activeTaskLoad);
}
final double activeDiff = maxActive - minActive;
if (activeDiff > 1.0) {
final StringBuilder builder = new StringBuilder()
.append("detected an active assignment balance factor violation: ")
.append(activeDiff)
.append(">")
.append(1.0)
.append(" in: ");
appendClientStates(builder, clientStates);
fail(builder.append(failureContext).toString());
}
}
static void assertBalancedTasks(final Map<UUID, ClientState> clientStates) {
final TaskSkewReport taskSkewReport = analyzeTaskAssignmentBalance(clientStates);
if (taskSkewReport.totalSkewedTasks() > 0) {
fail("Expected a balanced task assignment, but was: " + taskSkewReport);
}
}
static TaskSkewReport analyzeTaskAssignmentBalance(final Map<UUID, ClientState> clientStates) {
final Function<Integer, Map<UUID, AtomicInteger>> initialClientCounts =
i -> clientStates.keySet().stream().collect(Collectors.toMap(c -> c, c -> new AtomicInteger(0)));
final Map<Integer, Map<UUID, AtomicInteger>> subtopologyToClientsWithPartition = new TreeMap<>();
for (final Map.Entry<UUID, ClientState> entry : clientStates.entrySet()) {
final UUID client = entry.getKey();
final ClientState clientState = entry.getValue();
for (final TaskId task : clientState.activeTasks()) {
final int subtopology = task.topicGroupId;
subtopologyToClientsWithPartition
.computeIfAbsent(subtopology, initialClientCounts)
.get(client)
.incrementAndGet();
}
}
int maxTaskSkew = 0;
final Set<Integer> skewedSubtopologies = new TreeSet<>();
for (final Map.Entry<Integer, Map<UUID, AtomicInteger>> entry : subtopologyToClientsWithPartition.entrySet()) {
final Map<UUID, AtomicInteger> clientsWithPartition = entry.getValue();
int max = Integer.MIN_VALUE;
int min = Integer.MAX_VALUE;
for (final AtomicInteger count : clientsWithPartition.values()) {
max = Math.max(max, count.get());
min = Math.min(min, count.get());
}
final int taskSkew = max - min;
maxTaskSkew = Math.max(maxTaskSkew, taskSkew);
if (taskSkew > 1) {
skewedSubtopologies.add(entry.getKey());
}
}
return new TaskSkewReport(maxTaskSkew, skewedSubtopologies, subtopologyToClientsWithPartition);
}
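// Illustrative example of the skew computation, with hypothetical counts: if
// subtopology 0's active tasks land as {client1: 2, client2: 0}, its task skew is
// 2 - 0 = 2 > 1 and subtopology 0 is recorded as skewed; a {client1: 1, client2: 0}
// split (skew 1) would not be.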
static Matcher<ClientState> hasActiveTasks(final int taskCount) {
return hasProperty("activeTasks", ClientState::activeTaskCount, taskCount);
}
static Matcher<ClientState> hasStandbyTasks(final int taskCount) {
return hasProperty("standbyTasks", ClientState::standbyTaskCount, taskCount);
}
static <V> Matcher<ClientState> hasProperty(final String propertyName,
final Function<ClientState, V> propertyExtractor,
final V propertyValue) {
return new BaseMatcher<ClientState>() {
@Override
public void describeTo(final Description description) {
description.appendText(propertyName).appendText(":").appendValue(propertyValue);
}
@Override
public boolean matches(final Object actual) {
if (actual instanceof ClientState) {
return Objects.equals(propertyExtractor.apply((ClientState) actual), propertyValue);
} else {
return false;
}
}
};
}
static void appendClientStates(final StringBuilder stringBuilder,
final Map<UUID, ClientState> clientStates) {
stringBuilder.append('{').append('\n');
for (final Map.Entry<UUID, ClientState> entry : clientStates.entrySet()) {
stringBuilder.append(" ").append(entry.getKey()).append(": ").append(entry.getValue()).append('\n');
}
stringBuilder.append('}').append('\n');
}
static final class TaskSkewReport {
private final int maxTaskSkew;
private final Set<Integer> skewedSubtopologies;
private final Map<Integer, Map<UUID, AtomicInteger>> subtopologyToClientsWithPartition;
private TaskSkewReport(final int maxTaskSkew,
final Set<Integer> skewedSubtopologies,
final Map<Integer, Map<UUID, AtomicInteger>> subtopologyToClientsWithPartition) {
this.maxTaskSkew = maxTaskSkew;
this.skewedSubtopologies = skewedSubtopologies;
this.subtopologyToClientsWithPartition = subtopologyToClientsWithPartition;
}
int totalSkewedTasks() {
return skewedSubtopologies.size();
}
Set<Integer> skewedSubtopologies() {
return skewedSubtopologies;
}
@Override
public String toString() {
return "TaskSkewReport{" +
"maxTaskSkew=" + maxTaskSkew +
", skewedSubtopologies=" + skewedSubtopologies +
", subtopologyToClientsWithPartition=" + subtopologyToClientsWithPartition +
'}';
}
}
}

100
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java

@@ -23,6 +23,7 @@ import org.junit.Test;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -31,6 +32,8 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasActiveTasks;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasStandbyTasks;
import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
@@ -83,6 +86,103 @@ public class ClientStateTest {
assertTrue(client.reachedCapacity());
}
@Test
public void shouldRefuseDoubleActiveTask() {
final ClientState clientState = new ClientState(1);
clientState.assignActive(TASK_0_0);
assertThrows(IllegalArgumentException.class, () -> clientState.assignActive(TASK_0_0));
}
@Test
public void shouldRefuseActiveAndStandbyTask() {
final ClientState clientState = new ClientState(1);
clientState.assignActive(TASK_0_0);
assertThrows(IllegalArgumentException.class, () -> clientState.assignStandby(TASK_0_0));
}
@Test
public void shouldRefuseDoubleStandbyTask() {
final ClientState clientState = new ClientState(1);
clientState.assignStandby(TASK_0_0);
assertThrows(IllegalArgumentException.class, () -> clientState.assignStandby(TASK_0_0));
}
@Test
public void shouldRefuseStandbyAndActiveTask() {
final ClientState clientState = new ClientState(1);
clientState.assignStandby(TASK_0_0);
assertThrows(IllegalArgumentException.class, () -> clientState.assignActive(TASK_0_0));
}
@Test
public void shouldRefuseToUnassignNotAssignedActiveTask() {
final ClientState clientState = new ClientState(1);
assertThrows(IllegalArgumentException.class, () -> clientState.unassignActive(TASK_0_0));
}
@Test
public void shouldRefuseToUnassignNotAssignedStandbyTask() {
final ClientState clientState = new ClientState(1);
assertThrows(IllegalArgumentException.class, () -> clientState.unassignStandby(TASK_0_0));
}
@Test
public void shouldRefuseToUnassignActiveTaskAsStandby() {
final ClientState clientState = new ClientState(1);
clientState.assignActive(TASK_0_0);
assertThrows(IllegalArgumentException.class, () -> clientState.unassignStandby(TASK_0_0));
}
@Test
public void shouldRefuseToUnassignStandbyTaskAsActive() {
final ClientState clientState = new ClientState(1);
clientState.assignStandby(TASK_0_0);
assertThrows(IllegalArgumentException.class, () -> clientState.unassignActive(TASK_0_0));
}
@Test
public void shouldUnassignActiveTask() {
final ClientState clientState = new ClientState(1);
clientState.assignActive(TASK_0_0);
assertThat(clientState, hasActiveTasks(1));
clientState.unassignActive(TASK_0_0);
assertThat(clientState, hasActiveTasks(0));
}
@Test
public void shouldUnassignStandbyTask() {
final ClientState clientState = new ClientState(1);
clientState.assignStandby(TASK_0_0);
assertThat(clientState, hasStandbyTasks(1));
clientState.unassignStandby(TASK_0_0);
assertThat(clientState, hasStandbyTasks(0));
}
@Test
public void shouldNotModifyActiveView() {
final ClientState clientState = new ClientState(1);
final Set<TaskId> taskIds = clientState.activeTasks();
assertThrows(UnsupportedOperationException.class, () -> taskIds.add(TASK_0_0));
assertThat(clientState, hasActiveTasks(0));
}
@Test
public void shouldNotModifyStandbyView() {
final ClientState clientState = new ClientState(1);
final Set<TaskId> taskIds = clientState.standbyTasks();
assertThrows(UnsupportedOperationException.class, () -> taskIds.add(TASK_0_0));
assertThat(clientState, hasStandbyTasks(0));
}
@Test
public void shouldNotModifyAssignedView() {
final ClientState clientState = new ClientState(1);
final Set<TaskId> taskIds = clientState.assignedTasks();
assertThrows(UnsupportedOperationException.class, () -> taskIds.add(TASK_0_0));
assertThat(clientState, hasActiveTasks(0));
assertThat(clientState, hasStandbyTasks(0));
}
@Test
public void shouldAddActiveTasksToBothAssignedAndActive() {
client.assignActive(TASK_0_1);

111
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySetTest.java

@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import org.apache.kafka.streams.processor.TaskId;
import org.junit.Test;
import java.util.UUID;
import java.util.function.BiFunction;
import static java.util.Arrays.asList;
import static java.util.Collections.singleton;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
public class ConstrainedPrioritySetTest {
private static final TaskId DUMMY_TASK = new TaskId(0, 0);
private final BiFunction<UUID, TaskId, Boolean> alwaysTrue = (client, task) -> true;
private final BiFunction<UUID, TaskId, Boolean> alwaysFalse = (client, task) -> false;
@Test
public void shouldReturnOnlyClient() {
final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(alwaysTrue, client -> 1.0);
queue.offerAll(singleton(UUID_1));
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
assertThat(queue.poll(DUMMY_TASK), nullValue());
}
@Test
public void shouldReturnNull() {
final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(alwaysFalse, client -> 1.0);
queue.offerAll(singleton(UUID_1));
assertThat(queue.poll(DUMMY_TASK), nullValue());
}
@Test
public void shouldReturnLeastLoadedClient() {
final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(
alwaysTrue,
client -> (client == UUID_1) ? 3.0 : (client == UUID_2) ? 2.0 : 1.0
);
queue.offerAll(asList(UUID_1, UUID_2, UUID_3));
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_3));
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_2));
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
assertThat(queue.poll(DUMMY_TASK), nullValue());
}
@Test
public void shouldNotRetainDuplicates() {
final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(alwaysTrue, client -> 1.0);
queue.offerAll(singleton(UUID_1));
queue.offer(UUID_1);
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
assertThat(queue.poll(DUMMY_TASK), nullValue());
}
@Test
public void shouldOnlyReturnValidClients() {
final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(
(client, task) -> client.equals(UUID_1),
client -> 1.0
);
queue.offerAll(asList(UUID_1, UUID_2));
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
assertThat(queue.poll(DUMMY_TASK), nullValue());
}
@Test
public void shouldApplyPollFilter() {
final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(
alwaysTrue,
client -> 1.0
);
queue.offerAll(asList(UUID_1, UUID_2));
assertThat(queue.poll(DUMMY_TASK, client -> client.equals(UUID_1)), equalTo(UUID_1));
assertThat(queue.poll(DUMMY_TASK, client -> client.equals(UUID_1)), nullValue());
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_2));
assertThat(queue.poll(DUMMY_TASK), nullValue());
}
}
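The tests above pin down the poll/offer contract; the following rough usage sketch (not from the commit, and placed in the same package so the two-argument constructor exercised above is visible) shows the loop an assignor can run with it: poll the least-loaded valid client for each task, record the assignment, and offer the client back so its priority reflects the new load.

package org.apache.kafka.streams.processor.internals.assignment;

import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import org.apache.kafka.streams.processor.TaskId;

import static java.util.Arrays.asList;

public class ConstrainedPrioritySetUsageSketch {
    public static void main(final String[] args) {
        final UUID client1 = new UUID(0, 1);
        final UUID client2 = new UUID(0, 2);
        final Map<UUID, Integer> loads = new HashMap<>();
        loads.put(client1, 0);
        loads.put(client2, 0);

        // Same constructor shape as the tests above: a validity constraint plus a
        // load function that drives the priority ordering.
        final ConstrainedPrioritySet clientsByLoad = new ConstrainedPrioritySet(
            (client, task) -> true,
            client -> (double) loads.get(client)
        );
        clientsByLoad.offerAll(asList(client1, client2));

        for (final TaskId task : asList(new TaskId(0, 0), new TaskId(0, 1), new TaskId(0, 2))) {
            final UUID client = clientsByLoad.poll(task); // least-loaded valid client
            loads.merge(client, 1, Integer::sum);         // record the new assignment...
            clientsByLoad.offer(client);                  // ...and re-queue it with a fresh load
        }

        // One client ends up with two tasks and the other with one: each poll hands
        // back whichever client currently reports the lowest load.
        System.out.println(loads);
    }
}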

295
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/DefaultBalancedAssignorTest.java

@@ -1,295 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import java.util.UUID;
import org.apache.kafka.streams.processor.TaskId;
import org.junit.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
import static org.apache.kafka.common.utils.Utils.mkSortedSet;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
public class DefaultBalancedAssignorTest {
private static final SortedSet<UUID> TWO_CLIENTS = new TreeSet<>(Arrays.asList(UUID_1, UUID_2));
private static final SortedSet<UUID> THREE_CLIENTS = new TreeSet<>(Arrays.asList(UUID_1, UUID_2, UUID_3));
@Test
public void shouldAssignTasksEvenlyOverClientsWhereNumberOfClientsIntegralDivisorOfNumberOfTasks() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
THREE_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1,
TASK_0_2,
TASK_1_0,
TASK_1_1,
TASK_1_2,
TASK_2_0,
TASK_2_1,
TASK_2_2
),
threeClientsToNumberOfStreamThreads(1, 1, 1)
);
final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_0, TASK_1_0, TASK_2_0);
final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_0_1, TASK_1_1, TASK_2_1);
final List<TaskId> assignedTasksForClient3 = Arrays.asList(TASK_0_2, TASK_1_2, TASK_2_2);
assertThat(
assignment,
is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
);
}
@Test
public void shouldAssignTasksEvenlyOverClientsWhereNumberOfClientsNotIntegralDivisorOfNumberOfTasks() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
TWO_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1,
TASK_0_2,
TASK_1_0,
TASK_1_1,
TASK_1_2,
TASK_2_0,
TASK_2_1,
TASK_2_2
),
twoClientsToNumberOfStreamThreads(1, 1)
);
final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_0, TASK_0_2, TASK_1_1, TASK_2_0, TASK_2_2);
final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_0_1, TASK_1_0, TASK_1_2, TASK_2_1);
assertThat(
assignment,
is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2))
);
}
@Test
public void shouldAssignTasksEvenlyOverClientsWhereNumberOfStreamThreadsIntegralDivisorOfNumberOfTasks() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
THREE_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1,
TASK_0_2,
TASK_1_0,
TASK_1_1,
TASK_1_2,
TASK_2_0,
TASK_2_1,
TASK_2_2
),
threeClientsToNumberOfStreamThreads(3, 3, 3)
);
final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_0, TASK_1_0, TASK_2_0);
final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_0_1, TASK_1_1, TASK_2_1);
final List<TaskId> assignedTasksForClient3 = Arrays.asList(TASK_0_2, TASK_1_2, TASK_2_2);
assertThat(
assignment,
is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
);
}
@Test
public void shouldAssignTasksEvenlyOverClientsWhereNumberOfStreamThreadsNotIntegralDivisorOfNumberOfTasks() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
THREE_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1,
TASK_0_2,
TASK_1_0,
TASK_1_1,
TASK_1_2,
TASK_2_0,
TASK_2_1,
TASK_2_2
),
threeClientsToNumberOfStreamThreads(2, 2, 2)
);
final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_0, TASK_1_0, TASK_2_0);
final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_0_1, TASK_1_1, TASK_2_1);
final List<TaskId> assignedTasksForClient3 = Arrays.asList(TASK_0_2, TASK_1_2, TASK_2_2);
assertThat(
assignment,
is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
);
}
@Test
public void shouldAssignTasksEvenlyOverUnevenlyDistributedStreamThreads() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
THREE_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1,
TASK_0_2,
TASK_1_0,
TASK_1_1,
TASK_1_2,
TASK_2_0,
TASK_2_1,
TASK_2_2
),
threeClientsToNumberOfStreamThreads(1, 2, 3)
);
final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_1_0, TASK_2_0);
final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_0_1, TASK_1_1, TASK_2_1, TASK_0_0);
final List<TaskId> assignedTasksForClient3 = Arrays.asList(TASK_0_2, TASK_1_2, TASK_2_2);
assertThat(
assignment,
is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
);
}
@Test
public void shouldAssignTasksEvenlyOverClientsWithLessClientsThanTasks() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
THREE_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1
),
threeClientsToNumberOfStreamThreads(1, 1, 1)
);
final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_0);
final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_0_1);
final List<TaskId> assignedTasksForClient3 = Collections.emptyList();
assertThat(
assignment,
is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
);
}
@Test
public void shouldAssignTasksEvenlyOverClientsAndStreamThreadsWithMoreStreamThreadsThanTasks() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
THREE_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1,
TASK_0_2,
TASK_1_0,
TASK_1_1,
TASK_1_2,
TASK_2_0,
TASK_2_1,
TASK_2_2
),
threeClientsToNumberOfStreamThreads(6, 6, 6)
);
final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_0, TASK_1_0, TASK_2_0);
final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_0_1, TASK_1_1, TASK_2_1);
final List<TaskId> assignedTasksForClient3 = Arrays.asList(TASK_0_2, TASK_1_2, TASK_2_2);
assertThat(
assignment,
is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
);
}
@Test
public void shouldAssignTasksEvenlyOverStreamThreadsButBestEffortOverClients() {
final Map<UUID, List<TaskId>> assignment = new DefaultBalancedAssignor().assign(
TWO_CLIENTS,
mkSortedSet(
TASK_0_0,
TASK_0_1,
TASK_0_2,
TASK_1_0,
TASK_1_1,
TASK_1_2,
TASK_2_0,
TASK_2_1,
TASK_2_2
),
twoClientsToNumberOfStreamThreads(6, 2)
);
final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_0, TASK_0_2, TASK_1_1, TASK_2_0, TASK_2_2,
TASK_0_1);
final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_1_0, TASK_1_2, TASK_2_1);
assertThat(
assignment,
is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2))
);
}
private static Map<UUID, Integer> twoClientsToNumberOfStreamThreads(final int numberOfStreamThread1,
final int numberOfStreamThread2) {
return mkMap(
mkEntry(UUID_1, numberOfStreamThread1),
mkEntry(UUID_2, numberOfStreamThread2)
);
}
private static Map<UUID, Integer> threeClientsToNumberOfStreamThreads(final int numberOfStreamThread1,
final int numberOfStreamThread2,
final int numberOfStreamThread3) {
return mkMap(
mkEntry(UUID_1, numberOfStreamThread1),
mkEntry(UUID_2, numberOfStreamThread2),
mkEntry(UUID_3, numberOfStreamThread3)
);
}
private static Map<UUID, List<TaskId>> expectedAssignmentForThreeClients(final List<TaskId> assignedTasksForClient1,
final List<TaskId> assignedTasksForClient2,
final List<TaskId> assignedTasksForClient3) {
return mkMap(
mkEntry(UUID_1, assignedTasksForClient1),
mkEntry(UUID_2, assignedTasksForClient2),
mkEntry(UUID_3, assignedTasksForClient3)
);
}
private static Map<UUID, List<TaskId>> expectedAssignmentForTwoClients(final List<TaskId> assignedTasksForClient1,
final List<TaskId> assignedTasksForClient2) {
return mkMap(
mkEntry(UUID_1, assignedTasksForClient1),
mkEntry(UUID_2, assignedTasksForClient2)
);
}
}

309
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java

@@ -18,7 +18,6 @@ package org.apache.kafka.streams.processor.internals.assignment;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
import org.easymock.EasyMock;
import org.junit.Test;
import java.util.HashMap;
@@ -45,18 +44,23 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_3;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_3;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.analyzeTaskAssignmentBalance;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedActiveAssignment;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedStatefulAssignment;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedTasks;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertValidAssignment;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
import static org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor.computeBalanceFactor;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.replay;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasActiveTasks;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasStandbyTasks;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.fail;
public class HighAvailabilityTaskAssignorTest {
private final AssignmentConfigs configWithoutStandbys = new AssignmentConfigs(
@@ -73,6 +77,164 @@ public class HighAvailabilityTaskAssignorTest {
/*probingRebalanceIntervalMs*/ 60 * 1000L
);
@Test
public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsIntegralDivisorOfNumberOfTasks() {
final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1);
final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 1);
final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
final boolean unstable = new HighAvailabilityTaskAssignor().assign(
clientStates,
allTaskIds,
allTaskIds,
new AssignmentConfigs(0L, 0, 0, 0L)
);
assertThat(unstable, is(false));
assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfThreadsIntegralDivisorOfNumberOfTasks() {
final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 3);
final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 3);
final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 3);
final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
final boolean unstable = new HighAvailabilityTaskAssignor().assign(
clientStates,
allTaskIds,
allTaskIds,
new AssignmentConfigs(0L, 0, 0, 0L)
);
assertThat(unstable, is(false));
assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsNotIntegralDivisorOfNumberOfTasks() {
final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1);
final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2);
final boolean unstable = new HighAvailabilityTaskAssignor().assign(
clientStates,
allTaskIds,
allTaskIds,
new AssignmentConfigs(0L, 0, 0, 0L)
);
assertThat(unstable, is(false));
assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
public void shouldAssignActiveStatefulTasksEvenlyOverUnevenlyDistributedStreamThreads() {
final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 2);
final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 3);
final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
final boolean unstable = new HighAvailabilityTaskAssignor().assign(
clientStates,
allTaskIds,
allTaskIds,
new AssignmentConfigs(0L, 0, 0, 0L)
);
assertThat(unstable, is(false));
assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
assertThat(clientState1, hasActiveTasks(1));
assertThat(clientState2, hasActiveTasks(2));
assertThat(clientState3, hasActiveTasks(3));
final AssignmentTestUtils.TaskSkewReport taskSkewReport = analyzeTaskAssignmentBalance(clientStates);
if (taskSkewReport.totalSkewedTasks() == 0) {
fail("Expected a skewed task assignment, but was: " + taskSkewReport);
}
}
@Test
public void shouldAssignActiveStatefulTasksEvenlyOverClientsWithLessClientsThanTasks() {
final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1);
final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1);
final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 1);
final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
final boolean unstable = new HighAvailabilityTaskAssignor().assign(
clientStates,
allTaskIds,
allTaskIds,
new AssignmentConfigs(0L, 0, 0, 0L)
);
assertThat(unstable, is(false));
assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
public void shouldAssignActiveStatefulTasksEvenlyOverClientsAndStreamThreadsWithMoreStreamThreadsThanTasks() {
final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 6);
final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 6);
final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 6);
final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
final boolean unstable = new HighAvailabilityTaskAssignor().assign(
clientStates,
allTaskIds,
allTaskIds,
new AssignmentConfigs(0L, 0, 0, 0L)
);
assertThat(unstable, is(false));
assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
public void shouldAssignActiveStatefulTasksEvenlyOverStreamThreadsButBestEffortOverClients() {
final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 6);
final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 3);
final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2);
final boolean unstable = new HighAvailabilityTaskAssignor().assign(
clientStates,
allTaskIds,
allTaskIds,
new AssignmentConfigs(0L, 0, 0, 0L)
);
assertThat(unstable, is(false));
assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
assertThat(clientState1, hasActiveTasks(6));
assertThat(clientState2, hasActiveTasks(3));
}
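// The arithmetic behind this expectation: 9 stateful tasks over 6 + 3 = 9 stream
// threads is exactly one task per thread, so the per-thread active load is balanced
// even though the per-client counts (6 vs. 3) are deliberately uneven.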
@Test
public void shouldComputeNewAssignmentIfThereAreUnassignedActiveTasks() {
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
@@ -84,9 +246,14 @@ public class HighAvailabilityTaskAssignorTest {
singleton(TASK_0_0),
configWithoutStandbys);
assertThat(clientStates.get(UUID_1).activeTasks(), not(singleton(TASK_0_0)));
assertThat(clientStates.get(UUID_1).standbyTasks(), empty());
assertThat(probingRebalanceNeeded, is(false));
assertThat(client1, hasActiveTasks(2));
assertThat(client1, hasStandbyTasks(0));
assertValidAssignment(0, allTasks, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
@@ -104,6 +271,10 @@
assertThat(clientStates.get(UUID_2).standbyTasks(), not(empty()));
assertThat(probingRebalanceNeeded, is(false));
assertValidAssignment(1, allTasks, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
@@ -125,89 +296,10 @@
// we'll warm up task 0_0 on client1 because it's first in sorted order,
// although this isn't an optimal convergence
assertThat(probingRebalanceNeeded, is(true));
}
@Test
public void shouldComputeBalanceFactorAsDifferenceBetweenMostAndLeastLoadedClients() {
final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
final ClientState client2 = EasyMock.createNiceMock(ClientState.class);
final ClientState client3 = EasyMock.createNiceMock(ClientState.class);
final Set<ClientState> states = mkSet(client1, client2, client3);
final Set<TaskId> statefulTasks =
mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3);
expect(client1.capacity()).andReturn(1);
expect(client1.prevActiveTasks()).andReturn(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
expect(client2.capacity()).andReturn(1);
expect(client2.prevActiveTasks()).andReturn(mkSet(TASK_1_0, TASK_1_1));
expect(client3.capacity()).andReturn(1);
expect(client3.prevActiveTasks()).andReturn(mkSet(TASK_2_0, TASK_2_1, TASK_2_3));
replay(client1, client2, client3);
assertThat(computeBalanceFactor(states, statefulTasks), equalTo(2));
}
@Test
public void shouldComputeBalanceFactorWithDifferentClientCapacities() {
final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
final ClientState client2 = EasyMock.createNiceMock(ClientState.class);
final ClientState client3 = EasyMock.createNiceMock(ClientState.class);
final Set<ClientState> states = mkSet(client1, client2, client3);
final Set<TaskId> statefulTasks =
mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3);
// client 1: 4 tasks per thread
expect(client1.capacity()).andReturn(1);
expect(client1.prevActiveTasks()).andReturn(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
// client 2: 1 task per thread
expect(client2.capacity()).andReturn(2);
expect(client2.prevActiveTasks()).andReturn(mkSet(TASK_1_0, TASK_1_1));
// client 3: 1 task per thread
expect(client3.capacity()).andReturn(3);
expect(client3.prevActiveTasks()).andReturn(mkSet(TASK_2_0, TASK_2_1, TASK_2_3));
replay(client1, client2, client3);
assertThat(computeBalanceFactor(states, statefulTasks), equalTo(3));
}
@Test
public void shouldComputeBalanceFactorBasedOnStatefulTasksOnly() {
final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
final ClientState client2 = EasyMock.createNiceMock(ClientState.class);
final ClientState client3 = EasyMock.createNiceMock(ClientState.class);
final Set<ClientState> states = mkSet(client1, client2, client3);
// 0_0 and 0_1 are stateless
final Set<TaskId> statefulTasks = mkSet(TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3);
// client 1: 2 stateful tasks per thread
expect(client1.capacity()).andReturn(1);
expect(client1.prevActiveTasks()).andReturn(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
// client 2: 1 stateful task per thread
expect(client2.capacity()).andReturn(2);
expect(client2.prevActiveTasks()).andReturn(mkSet(TASK_1_0, TASK_1_1));
// client 3: 1 stateful task per thread
expect(client3.capacity()).andReturn(3);
expect(client3.prevActiveTasks()).andReturn(mkSet(TASK_2_0, TASK_2_1, TASK_2_3));
replay(client1, client2, client3);
assertThat(computeBalanceFactor(states, statefulTasks), equalTo(1));
}
@Test
public void shouldComputeBalanceFactorOfZeroWithOnlyOneClient() {
final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
expect(client1.capacity()).andReturn(1);
expect(client1.prevActiveTasks()).andReturn(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
replay(client1);
assertThat(computeBalanceFactor(singleton(client1), statefulTasks), equalTo(0));
assertValidAssignment(0, 1, allTasks, emptySet(), clientStates, new StringBuilder());
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
assertBalancedTasks(clientStates);
}
@Test
@@ -366,11 +458,15 @@
final boolean probingRebalanceNeeded =
new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
assertThat(client1.activeTaskCount(), equalTo(4));
assertThat(client2.standbyTaskCount(), equalTo(3));
assertThat(client3.standbyTaskCount(), equalTo(3));
assertHasNoStandbyTasks(client1);
assertHasNoActiveTasks(client2, client3);
assertValidAssignment(
1,
2,
statefulTasks,
emptySet(),
clientStates,
new StringBuilder()
);
assertThat(probingRebalanceNeeded, is(true));
}
@@ -378,6 +474,7 @@
public void shouldDistributeStatelessTasksToBalanceTotalTaskLoad() {
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2);
final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
final Set<TaskId> statelessTasks = mkSet(TASK_1_0, TASK_1_1, TASK_1_2);
final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks);
final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
@@ -386,10 +483,22 @@
final boolean probingRebalanceNeeded =
new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_2)));
assertHasNoStandbyTasks(client1);
assertThat(client2.activeTasks(), equalTo(mkSet(TASK_1_1)));
assertThat(client2.standbyTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
assertValidAssignment(
1,
2,
statefulTasks,
statelessTasks,
clientStates,
new StringBuilder()
);
assertBalancedActiveAssignment(clientStates, new StringBuilder());
assertBalancedStatefulAssignment(statefulTasks, clientStates, new StringBuilder());
// since only client1 is caught up on the stateful tasks, we expect it to get _all_ the active tasks,
// which means that client2 should have gotten all of the stateless tasks, so the tasks should be skewed
final AssignmentTestUtils.TaskSkewReport taskSkewReport = analyzeTaskAssignmentBalance(clientStates);
assertThat(taskSkewReport.toString(), taskSkewReport.skewedSubtopologies(), not(empty()));
assertThat(probingRebalanceNeeded, is(true));
}
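// Quantifying the skew expectation above with this test's numbers: all four stateful
// actives (subtopology 0) sit on client1, the only caught-up client, while the three
// stateless tasks (subtopology 1) pile onto client2 to even out the total load, so at
// least one subtopology sees a per-client spread greater than 1 and
// analyzeTaskAssignmentBalance flags it as skewed.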
@@ -466,7 +575,7 @@
private static void assertHasNoStandbyTasks(final ClientState... clients) {
for (final ClientState client : clients) {
assertThat(client.standbyTasks(), is(empty()));
assertThat(client, hasStandbyTasks(0));
}
}

196
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RankedClientTest.java

@@ -1,196 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import static java.util.Collections.emptySet;
import static java.util.Collections.singleton;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
import static org.apache.kafka.common.utils.Utils.mkSortedSet;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_4;
import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.buildClientRankingsByTask;
import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.tasksToCaughtUpClients;
import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.replay;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertTrue;
import java.util.Map;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.UUID;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.Task;
import org.easymock.EasyMock;
import org.junit.Test;
public class RankedClientTest {
private static final long ACCEPTABLE_RECOVERY_LAG = 100L;
private ClientState client1 = EasyMock.createNiceMock(ClientState.class);
private ClientState client2 = EasyMock.createNiceMock(ClientState.class);
private ClientState client3 = EasyMock.createNiceMock(ClientState.class);
@Test
public void shouldRankPreviousClientAboveEquallyCaughtUpClient() {
expect(client1.lagFor(TASK_0_0)).andReturn(Task.LATEST_OFFSET);
expect(client2.lagFor(TASK_0_0)).andReturn(0L);
replay(client1, client2);
final SortedSet<RankedClient> expectedClientRanking = mkSortedSet(
new RankedClient(UUID_1, Task.LATEST_OFFSET),
new RankedClient(UUID_2, 0L)
);
final Map<UUID, ClientState> states = mkMap(
mkEntry(UUID_1, client1),
mkEntry(UUID_2, client2)
);
final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
buildClientRankingsByTask(singleton(TASK_0_0), states, ACCEPTABLE_RECOVERY_LAG);
final SortedSet<RankedClient> clientRanking = statefulTasksToRankedCandidates.get(TASK_0_0);
EasyMock.verify(client1, client2);
assertThat(clientRanking, equalTo(expectedClientRanking));
}
@Test
public void shouldRankTaskWithUnknownOffsetSumBelowCaughtUpClientAndClientWithLargeLag() {
expect(client1.lagFor(TASK_0_0)).andReturn(UNKNOWN_OFFSET_SUM);
expect(client2.lagFor(TASK_0_0)).andReturn(50L);
expect(client3.lagFor(TASK_0_0)).andReturn(500L);
replay(client1, client2, client3);
final SortedSet<RankedClient> expectedClientRanking = mkSortedSet(
new RankedClient(UUID_2, 0L),
new RankedClient(UUID_1, 1L),
new RankedClient(UUID_3, 500L)
);
final Map<UUID, ClientState> states = mkMap(
mkEntry(UUID_1, client1),
mkEntry(UUID_2, client2),
mkEntry(UUID_3, client3)
);
final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
buildClientRankingsByTask(singleton(TASK_0_0), states, ACCEPTABLE_RECOVERY_LAG);
final SortedSet<RankedClient> clientRanking = statefulTasksToRankedCandidates.get(TASK_0_0);
EasyMock.verify(client1, client2, client3);
assertThat(clientRanking, equalTo(expectedClientRanking));
}
@Test
public void shouldRankAllClientsWithinAcceptableRecoveryLagWithRank0() {
expect(client1.lagFor(TASK_0_0)).andReturn(100L);
expect(client2.lagFor(TASK_0_0)).andReturn(0L);
replay(client1, client2);
final SortedSet<RankedClient> expectedClientRanking = mkSortedSet(
new RankedClient(UUID_1, 0L),
new RankedClient(UUID_2, 0L)
);
final Map<UUID, ClientState> states = mkMap(
mkEntry(UUID_1, client1),
mkEntry(UUID_2, client2)
);
final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
buildClientRankingsByTask(singleton(TASK_0_0), states, ACCEPTABLE_RECOVERY_LAG);
EasyMock.verify(client1, client2);
assertThat(statefulTasksToRankedCandidates.get(TASK_0_0), equalTo(expectedClientRanking));
}
@Test
public void shouldRankNotCaughtUpClientsAccordingToLag() {
expect(client1.lagFor(TASK_0_0)).andReturn(900L);
expect(client2.lagFor(TASK_0_0)).andReturn(800L);
expect(client3.lagFor(TASK_0_0)).andReturn(500L);
replay(client1, client2, client3);
final SortedSet<RankedClient> expectedClientRanking = mkSortedSet(
new RankedClient(UUID_3, 500L),
new RankedClient(UUID_2, 800L),
new RankedClient(UUID_1, 900L)
);
final Map<UUID, ClientState> states = mkMap(
mkEntry(UUID_1, client1),
mkEntry(UUID_2, client2),
mkEntry(UUID_3, client3)
);
final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
buildClientRankingsByTask(singleton(TASK_0_0), states, ACCEPTABLE_RECOVERY_LAG);
EasyMock.verify(client1, client2, client3);
assertThat(statefulTasksToRankedCandidates.get(TASK_0_0), equalTo(expectedClientRanking));
}
@Test
public void shouldReturnEmptyClientRankingsWithNoStatefulTasks() {
final Map<UUID, ClientState> states = mkMap(
mkEntry(UUID_1, client1),
mkEntry(UUID_2, client2)
);
assertTrue(buildClientRankingsByTask(emptySet(), states, ACCEPTABLE_RECOVERY_LAG).isEmpty());
}
@Test
public void shouldReturnTasksToCaughtUpClients() {
final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates = new TreeMap<>();
statefulTasksToRankedCandidates.put(
TASK_0_0,
mkSortedSet(
new RankedClient(UUID_1, Task.LATEST_OFFSET),
new RankedClient(UUID_2, 0L),
new RankedClient(UUID_3, 1L),
new RankedClient(UUID_4, 1000L)
));
final SortedSet<UUID> expectedCaughtUpClients = mkSortedSet(UUID_1, UUID_2);
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = tasksToCaughtUpClients(statefulTasksToRankedCandidates);
assertThat(tasksToCaughtUpClients.get(TASK_0_0), equalTo(expectedCaughtUpClients));
}
@Test
public void shouldOnlyReturnTasksWithCaughtUpClients() {
final RankedClient rankedClient = new RankedClient(UUID_1, 1L);
final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates = new TreeMap<>();
statefulTasksToRankedCandidates.put(TASK_0_0, mkSortedSet(rankedClient));
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = tasksToCaughtUpClients(statefulTasksToRankedCandidates);
assertTrue(tasksToCaughtUpClients.isEmpty());
}
}

164
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java

@@ -18,7 +18,6 @@ package org.apache.kafka.streams.processor.internals.assignment;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
import org.hamcrest.MatcherAssert;
import org.junit.Test;
import java.util.Map;
@@ -28,11 +27,13 @@ import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import java.util.function.Supplier;
import static java.util.Collections.emptyMap;
import static org.apache.kafka.common.utils.Utils.entriesToMap;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.appendClientStates;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedActiveAssignment;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedStatefulAssignment;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertValidAssignment;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.uuidForInt;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.fail;
public class TaskAssignorConvergenceTest {
@@ -45,36 +46,29 @@ public class TaskAssignorConvergenceTest {
private static Harness initializeCluster(final int numStatelessTasks,
final int numStatefulTasks,
final int numNodes) {
final int numNodes,
final Supplier<Integer> partitionCountSupplier) {
int subtopology = 0;
final Set<TaskId> statelessTasks = new TreeSet<>();
{
int partition = 0;
for (int i = 0; i < numStatelessTasks; i++) {
statelessTasks.add(new TaskId(subtopology, partition));
if (partition == 4) {
subtopology++;
partition = 0;
} else {
partition++;
}
int remainingStatelessTasks = numStatelessTasks;
while (remainingStatelessTasks > 0) {
final int partitions = Math.min(remainingStatelessTasks, partitionCountSupplier.get());
for (int i = 0; i < partitions; i++) {
statelessTasks.add(new TaskId(subtopology, i));
remainingStatelessTasks--;
}
subtopology++;
}
final Map<TaskId, Long> statefulTaskEndOffsetSums = new TreeMap<>();
{
subtopology++;
int partition = 0;
for (int i = 0; i < numStatefulTasks; i++) {
statefulTaskEndOffsetSums.put(new TaskId(subtopology, partition), 150000L);
if (partition == 4) {
subtopology++;
partition = 0;
} else {
partition++;
}
int remainingStatefulTasks = numStatefulTasks;
while (remainingStatefulTasks > 0) {
final int partitions = Math.min(remainingStatefulTasks, partitionCountSupplier.get());
for (int i = 0; i < partitions; i++) {
statefulTaskEndOffsetSums.put(new TaskId(subtopology, i), 150000L);
remainingStatefulTasks--;
}
subtopology++;
}
final Map<UUID, ClientState> clientStates = new TreeMap<>();
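For example, with numStatelessTasks = 7 and a partitionCountSupplier that always returns 5, the stateless loop creates tasks 0_0 through 0_4 in subtopology 0 and the remaining two, 1_0 and 1_1, in subtopology 1; the stateful loop then fills later subtopologies the same way.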
@@ -200,6 +194,12 @@ public class TaskAssignorConvergenceTest {
clientStates.putAll(newClientStates);
}
private void recordConfig(final AssignmentConfigs configuration) {
history.append("Creating assignor with configuration: ")
.append(configuration)
.append('\n');
}
private void recordBefore(final int iteration) {
history.append("Starting Iteration: ").append(iteration).append('\n');
formatClientStates(false);
@@ -213,20 +213,17 @@ public class TaskAssignorConvergenceTest {
}
private void formatClientStates(final boolean printUnassigned) {
final Set<TaskId> unassignedTasks = new TreeSet<>();
unassignedTasks.addAll(statefulTaskEndOffsetSums.keySet());
unassignedTasks.addAll(statelessTasks);
history.append('{').append('\n');
for (final Map.Entry<UUID, ClientState> entry : clientStates.entrySet()) {
history.append(" ").append(entry.getKey()).append(": ").append(entry.getValue()).append('\n');
unassignedTasks.removeAll(entry.getValue().assignedTasks());
}
history.append('}').append('\n');
appendClientStates(history, clientStates);
if (printUnassigned) {
final Set<TaskId> unassignedTasks = new TreeSet<>();
unassignedTasks.addAll(statefulTaskEndOffsetSums.keySet());
unassignedTasks.addAll(statelessTasks);
for (final Map.Entry<UUID, ClientState> entry : clientStates.entrySet()) {
unassignedTasks.removeAll(entry.getValue().assignedTasks());
}
history.append("Unassigned Tasks: ").append(unassignedTasks).append('\n');
}
}
}
@Test
@@ -236,16 +233,17 @@ public class TaskAssignorConvergenceTest {
0,
1000L);
final Harness harness = Harness.initializeCluster(1, 1, 1);
final Harness harness = Harness.initializeCluster(1, 1, 1, () -> 1);
testForConvergence(harness, configs, 1);
verifyValidAssignment(0, harness);
verifyBalancedAssignment(harness);
}
@Test
public void assignmentShouldConvergeAfterAddingNode() {
final int numStatelessTasks = 15;
final int numStatefulTasks = 13;
final int numStatelessTasks = 7;
final int numStatefulTasks = 11;
final int maxWarmupReplicas = 2;
final int numStandbyReplicas = 0;
@@ -254,18 +252,19 @@ public class TaskAssignorConvergenceTest {
numStandbyReplicas,
1000L);
final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 1);
final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 1, () -> 5);
testForConvergence(harness, configs, 1);
harness.addNode();
// we expect convergence to involve moving each task at most once, and we can move
// "maxWarmupReplicas" tasks at a time, hence the iteration limit
testForConvergence(harness, configs, numStatefulTasks / maxWarmupReplicas + 1);
verifyValidAssignment(numStandbyReplicas, harness);
verifyBalancedAssignment(harness);
}
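With numStatefulTasks = 11 and maxWarmupReplicas = 2, that bound works out to 11 / 2 + 1 = 6 iterations.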
@Test
public void droppingNodesShouldConverge() {
final int numStatelessTasks = 15;
final int numStatelessTasks = 11;
final int numStatefulTasks = 13;
final int maxWarmupReplicas = 2;
final int numStandbyReplicas = 0;
@@ -275,7 +274,7 @@ public class TaskAssignorConvergenceTest {
numStandbyReplicas,
1000L);
final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 7);
final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 7, () -> 5);
testForConvergence(harness, configs, 1);
harness.dropNode();
// This time, we allow one extra iteration because the
@@ -283,6 +282,7 @@ public class TaskAssignorConvergenceTest {
testForConvergence(harness, configs, numStatefulTasks / maxWarmupReplicas + 2);
verifyValidAssignment(numStandbyReplicas, harness);
verifyBalancedAssignment(harness);
}
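Here the limit evaluates to 13 / 2 + 2 = 8 iterations.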
@Test
@@ -315,9 +315,15 @@ public class TaskAssignorConvergenceTest {
numStandbyReplicas,
1000L);
harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, initialClusterSize);
harness = Harness.initializeCluster(
numStatelessTasks,
numStatefulTasks,
initialClusterSize,
() -> prng.nextInt(10) + 1
);
testForConvergence(harness, configs, 1);
verifyValidAssignment(numStandbyReplicas, harness);
verifyBalancedAssignment(harness);
for (int i = 0; i < numberOfEvents; i++) {
final int event = prng.nextInt(2);
@@ -334,6 +340,7 @@ public class TaskAssignorConvergenceTest {
if (!harness.clientStates.isEmpty()) {
testForConvergence(harness, configs, numStatefulTasks * 2);
verifyValidAssignment(numStandbyReplicas, harness);
verifyBalancedAssignment(harness);
}
}
} catch (final AssertionError t) {
@@ -354,49 +361,32 @@ public class TaskAssignorConvergenceTest {
}
}
private static void verifyValidAssignment(final int numStandbyReplicas, final Harness harness) {
final Map<TaskId, Set<UUID>> assignments = new TreeMap<>();
for (final TaskId taskId : harness.statefulTaskEndOffsetSums.keySet()) {
assignments.put(taskId, new TreeSet<>());
}
for (final TaskId taskId : harness.statelessTasks) {
assignments.put(taskId, new TreeSet<>());
}
for (final Map.Entry<UUID, ClientState> entry : harness.clientStates.entrySet()) {
for (final TaskId activeTask : entry.getValue().activeTasks()) {
if (assignments.containsKey(activeTask)) {
assignments.get(activeTask).add(entry.getKey());
}
}
for (final TaskId standbyTask : entry.getValue().standbyTasks()) {
assignments.get(standbyTask).add(entry.getKey());
}
private static void verifyBalancedAssignment(final Harness harness) {
final Set<TaskId> allStatefulTasks = harness.statefulTaskEndOffsetSums.keySet();
final Map<UUID, ClientState> clientStates = harness.clientStates;
final StringBuilder failureContext = harness.history;
assertBalancedActiveAssignment(clientStates, failureContext);
assertBalancedStatefulAssignment(allStatefulTasks, clientStates, failureContext);
final AssignmentTestUtils.TaskSkewReport taskSkewReport = AssignmentTestUtils.analyzeTaskAssignmentBalance(harness.clientStates);
if (taskSkewReport.totalSkewedTasks() > 0) {
fail(
new StringBuilder().append("Expected a balanced task assignment, but was: ")
.append(taskSkewReport)
.append('\n')
.append(failureContext)
.toString()
);
}
final TreeMap<TaskId, Set<UUID>> misassigned =
assignments
.entrySet()
.stream()
.filter(entry -> {
final int expectedActives = 1;
final boolean isStateless = harness.statelessTasks.contains(entry.getKey());
final int expectedStandbys = isStateless ? 0 : numStandbyReplicas;
// We'll never assign even the expected number of standbys if they don't actually fit in the cluster
final int expectedAssignments = Math.min(
harness.clientStates.size(),
expectedActives + expectedStandbys
);
return entry.getValue().size() != expectedAssignments;
})
.collect(entriesToMap(TreeMap::new));
MatcherAssert.assertThat(
new StringBuilder().append("Found some over- or under-assigned tasks in the final assignment with ")
.append(numStandbyReplicas)
.append(" standby replicas.")
.append(harness.history)
.toString(),
misassigned,
is(emptyMap()));
}
private static void verifyValidAssignment(final int numStandbyReplicas, final Harness harness) {
final Set<TaskId> statefulTasks = harness.statefulTaskEndOffsetSums.keySet();
final Set<TaskId> statelessTasks = harness.statelessTasks;
final Map<UUID, ClientState> assignedStates = harness.clientStates;
final StringBuilder failureContext = harness.history;
assertValidAssignment(numStandbyReplicas, statefulTasks, statelessTasks, assignedStates, failureContext);
}
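Concretely, a valid assignment gives each stateless task exactly one owner (its active) and each stateful task 1 + numStandbyReplicas owners, capped at the number of clients when the cluster is too small to host every standby.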
private static void testForConvergence(final Harness harness,
@@ -406,6 +396,8 @@ public class TaskAssignorConvergenceTest {
allTasks.addAll(harness.statelessTasks);
allTasks.addAll(harness.statefulTaskEndOffsetSums.keySet());
harness.recordConfig(configs);
boolean rebalancePending = true;
int iteration = 0;
while (rebalancePending && iteration < iterationLimit) {

304
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java

@@ -16,15 +16,24 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;
import org.apache.kafka.streams.processor.TaskId;
import org.junit.Test;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.UUID;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
import static org.apache.kafka.common.utils.Utils.mkSet;
import static org.apache.kafka.common.utils.Utils.mkSortedSet;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASK_LIST;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
@@ -35,262 +44,161 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasProperty;
import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignTaskMovements;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.kafka.streams.processor.TaskId;
import org.junit.Test;
import static org.hamcrest.Matchers.is;
public class TaskMovementTest {
private final ClientState client1 = new ClientState(1);
private final ClientState client2 = new ClientState(1);
private final ClientState client3 = new ClientState(1);
private final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
private final Map<UUID, List<TaskId>> emptyWarmupAssignment = mkMap(
mkEntry(UUID_1, EMPTY_TASK_LIST),
mkEntry(UUID_2, EMPTY_TASK_LIST),
mkEntry(UUID_3, EMPTY_TASK_LIST)
);
@Test
public void shouldAssignTasksToClientsAndReturnFalseWhenAllClientsCaughtUp() {
final int maxWarmupReplicas = Integer.MAX_VALUE;
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
);
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
for (final TaskId task : allTasks) {
tasksToCaughtUpClients.put(task, mkSortedSet(UUID_1, UUID_2, UUID_3));
}
assertFalse(
final ClientState client1 = getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
final ClientState client2 = getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
final ClientState client3 = getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
assertThat(
assignTaskMovements(
balancedAssignment,
tasksToCaughtUpClients,
clientStates,
getMapWithNumStandbys(allTasks, 1),
maxWarmupReplicas)
getClientStatesMap(client1, client2, client3),
maxWarmupReplicas),
is(false)
);
verifyClientStateAssignments(balancedAssignment, emptyWarmupAssignment);
}
@Test
public void shouldAssignAllTasksToClientsAndReturnFalseIfNoClientsAreCaughtUp() {
final int maxWarmupReplicas = 2;
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
final int maxWarmupReplicas = Integer.MAX_VALUE;
final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
);
final ClientState client1 = getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
final ClientState client2 = getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
final ClientState client3 = getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
assertFalse(
assertThat(
assignTaskMovements(
balancedAssignment,
emptyMap(),
clientStates,
getMapWithNumStandbys(allTasks, 1),
maxWarmupReplicas)
getClientStatesMap(client1, client2, client3),
maxWarmupReplicas),
is(false)
);
verifyClientStateAssignments(balancedAssignment, emptyWarmupAssignment);
}
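Together, the two tests above fix the meaning of the return value: assignTaskMovements reports whether any movements (and hence warmup replicas) were assigned. It returns false both when every client is already caught up on its proposed tasks and when no client is caught up on anything, since neither case leaves a movement to make.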
@Test
public void shouldMoveTasksToCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() {
final int maxWarmupReplicas = Integer.MAX_VALUE;
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
final ClientState client1 = getClientStateWithActiveAssignment(singletonList(TASK_0_0));
final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_1));
final ClientState client3 = getClientStateWithActiveAssignment(singletonList(TASK_0_2));
final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
mkEntry(UUID_1, singletonList(TASK_0_0)),
mkEntry(UUID_2, singletonList(TASK_0_1)),
mkEntry(UUID_3, singletonList(TASK_0_2))
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
mkEntry(TASK_0_0, mkSortedSet(UUID_1)),
mkEntry(TASK_0_1, mkSortedSet(UUID_3)),
mkEntry(TASK_0_2, mkSortedSet(UUID_2))
);
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_1));
tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_3));
tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));
final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
mkEntry(UUID_1, singletonList(TASK_0_0)),
mkEntry(UUID_2, singletonList(TASK_0_2)),
mkEntry(UUID_3, singletonList(TASK_0_1))
);
final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
mkEntry(UUID_1, EMPTY_TASK_LIST),
mkEntry(UUID_2, singletonList(TASK_0_1)),
mkEntry(UUID_3, singletonList(TASK_0_2))
);
assertTrue(
assertThat(
"should have assigned movements",
assignTaskMovements(
balancedAssignment,
tasksToCaughtUpClients,
clientStates,
getMapWithNumStandbys(allTasks, 1),
maxWarmupReplicas)
);
verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
}
@Test
public void shouldProduceBalancedAndStateConstrainedAssignment() {
final int maxWarmupReplicas = Integer.MAX_VALUE;
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
);
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_2, UUID_3)); // needs to be warmed up
tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_1, UUID_3)); // needs to be warmed up
tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2)); // needs to be warmed up
tasksToCaughtUpClients.put(TASK_1_1, mkSortedSet(UUID_1)); // needs to be warmed up
final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
mkEntry(UUID_1, asList(TASK_1_0, TASK_1_1)),
mkEntry(UUID_2, asList(TASK_0_2, TASK_0_0)),
mkEntry(UUID_3, asList(TASK_0_1, TASK_1_2))
);
final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
mkEntry(UUID_1, singletonList(TASK_0_0)),
mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
mkEntry(UUID_3, singletonList(TASK_0_2))
);
assertTrue(
assignTaskMovements(
balancedAssignment,
tasksToCaughtUpClients,
clientStates,
getMapWithNumStandbys(allTasks, 1),
maxWarmupReplicas)
);
verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
maxWarmupReplicas),
is(true)
);
// The active tasks have changed to the ones that each client is caught up on
assertThat(client1, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_0)));
assertThat(client2, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_2)));
assertThat(client3, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_1)));
// we assigned warmup replicas so that processing can migrate back to the input active assignment
assertThat(client1, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_1)));
assertThat(client3, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_2)));
}
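The hasProperty matcher comes from AssignmentTestUtils; a plausible minimal sketch of such a helper (an assumed shape, not necessarily the real implementation) is a Hamcrest TypeSafeMatcher built around an extractor function:

    import java.util.function.Function;
    import org.hamcrest.Description;
    import org.hamcrest.Matcher;
    import org.hamcrest.TypeSafeMatcher;

    // Matches when extractor(item) equals expectedValue; propertyName is only
    // used to render a readable failure message.
    static <T, V> Matcher<T> hasPropertySketch(final String propertyName,
                                               final Function<T, V> extractor,
                                               final V expectedValue) {
        return new TypeSafeMatcher<T>() {
            @Override
            protected boolean matchesSafely(final T item) {
                return expectedValue.equals(extractor.apply(item));
            }

            @Override
            public void describeTo(final Description description) {
                description.appendText(propertyName).appendText(" == ").appendValue(expectedValue);
            }
        };
    }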
@Test
public void shouldOnlyGetUpToMaxWarmupReplicasAndReturnTrue() {
final int maxWarmupReplicas = 1;
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
final ClientState client1 = getClientStateWithActiveAssignment(singletonList(TASK_0_0));
final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_1));
final ClientState client3 = getClientStateWithActiveAssignment(singletonList(TASK_0_2));
final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
mkEntry(UUID_1, singletonList(TASK_0_0)),
mkEntry(UUID_2, singletonList(TASK_0_1)),
mkEntry(UUID_3, singletonList(TASK_0_2))
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
mkEntry(TASK_0_0, mkSortedSet(UUID_1)),
mkEntry(TASK_0_1, mkSortedSet(UUID_3)),
mkEntry(TASK_0_2, mkSortedSet(UUID_2))
);
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_1));
tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_3));
tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));
final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
mkEntry(UUID_1, singletonList(TASK_0_0)),
mkEntry(UUID_2, singletonList(TASK_0_2)),
mkEntry(UUID_3, singletonList(TASK_0_1))
);
final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
mkEntry(UUID_1, EMPTY_TASK_LIST),
mkEntry(UUID_2, singletonList(TASK_0_1)),
mkEntry(UUID_3, EMPTY_TASK_LIST)
);
assertTrue(
assertThat(
"should have assigned movements",
assignTaskMovements(
balancedAssignment,
tasksToCaughtUpClients,
clientStates,
getMapWithNumStandbys(allTasks, 1),
maxWarmupReplicas)
);
verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
tasksToCaughtUpClients,
clientStates,
maxWarmupReplicas),
is(true)
);
// The active tasks have changed to the ones that each client is caught up on
assertThat(client1, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_0)));
assertThat(client2, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_2)));
assertThat(client3, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_1)));
// we should only assign one warmup, but it could be either of the two tasks that need to be migrated.
assertThat(client1, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
try {
assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_1)));
assertThat(client3, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
} catch (final AssertionError ignored) {
assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
assertThat(client3, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_2)));
}
}
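The try/catch expresses an either-or assertion: with only one warmup allowed, exactly one of the two pending migrations gets its warmup replica, and which one wins depends on internal iteration order, so the test accepts both outcomes.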
@Test
public void shouldNotCountPreviousStandbyTasksTowardsMaxWarmupReplicas() {
final int maxWarmupReplicas = 1;
final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
final int maxWarmupReplicas = 0;
final ClientState client1 = getClientStateWithActiveAssignment(emptyList());
client1.assignStandby(TASK_0_0);
final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_0));
final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
mkEntry(UUID_1, singletonList(TASK_0_0)),
mkEntry(UUID_2, singletonList(TASK_0_1)),
mkEntry(UUID_3, singletonList(TASK_0_2))
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
mkEntry(TASK_0_0, mkSortedSet(UUID_1))
);
final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_1));
tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_3));
tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));
final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
mkEntry(UUID_1, singletonList(TASK_0_0)),
mkEntry(UUID_2, singletonList(TASK_0_2)),
mkEntry(UUID_3, singletonList(TASK_0_1))
);
final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
mkEntry(UUID_1, EMPTY_TASK_LIST),
mkEntry(UUID_2, singletonList(TASK_0_1)),
mkEntry(UUID_3, singletonList(TASK_0_2))
);
client3.addPreviousStandbyTasks(singleton(TASK_0_2));
assertTrue(
assertThat(
"should have assigned movements",
assignTaskMovements(
balancedAssignment,
tasksToCaughtUpClients,
clientStates,
getMapWithNumStandbys(allTasks, 1),
maxWarmupReplicas)
maxWarmupReplicas),
is(true)
);
// Even though we have no warmups allowed, we still let client1 take over active processing while
// client2 "warms up" because client1 was a caught-up standby, so it can "trade" standby status with
// the not-caught-up active client2.
verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
}
// I.e., when you have a caught-up standby and a not-caught-up active, you can just swap their roles
// and not call it a "warmup".
assertThat(client1, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_0)));
assertThat(client2, hasProperty("activeTasks", ClientState::activeTasks, mkSet()));
assertThat(client1, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_0)));
private void verifyClientStateAssignments(final Map<UUID, List<TaskId>> expectedActiveTaskAssignment,
final Map<UUID, List<TaskId>> expectedStandbyTaskAssignment) {
for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
final UUID client = clientEntry.getKey();
final ClientState state = clientEntry.getValue();
assertThat(state.activeTasks(), equalTo(new HashSet<>(expectedActiveTaskAssignment.get(client))));
assertThat(state.standbyTasks(), equalTo(new HashSet<>(expectedStandbyTaskAssignment.get(client))));
}
}
private static Map<TaskId, Integer> getMapWithNumStandbys(final Set<TaskId> tasks, final int numStandbys) {
return tasks.stream().collect(Collectors.toMap(task -> task, t -> numStandbys));
private static ClientState getClientStateWithActiveAssignment(final Collection<TaskId> activeTasks) {
final ClientState client1 = new ClientState(1);
client1.assignActiveTasks(activeTasks);
return client1;
}
}

126
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueueTest.java

@@ -1,126 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertNull;
import java.util.Map;
import java.util.UUID;
import java.util.function.BiFunction;
import org.apache.kafka.streams.processor.TaskId;
import org.junit.Test;
public class ValidClientsByTaskLoadQueueTest {
private static final TaskId DUMMY_TASK = new TaskId(0, 0);
private final ClientState client1 = new ClientState(1);
private final ClientState client2 = new ClientState(1);
private final ClientState client3 = new ClientState(1);
private final BiFunction<UUID, TaskId, Boolean> alwaysTrue = (client, task) -> true;
private final BiFunction<UUID, TaskId, Boolean> alwaysFalse = (client, task) -> false;
private ValidClientsByTaskLoadQueue queue;
private Map<UUID, ClientState> clientStates;
@Test
public void shouldReturnOnlyClient() {
clientStates = getClientStatesMap(client1);
queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
queue.offerAll(clientStates.keySet());
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
}
@Test
public void shouldReturnNull() {
clientStates = getClientStatesMap(client1);
queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysFalse);
queue.offerAll(clientStates.keySet());
assertNull(queue.poll(DUMMY_TASK));
}
@Test
public void shouldReturnLeastLoadedClient() {
clientStates = getClientStatesMap(client1, client2, client3);
queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
client1.assignActive(TASK_0_0);
client2.assignActiveTasks(asList(TASK_0_1, TASK_1_1));
client3.assignActiveTasks(asList(TASK_0_2, TASK_1_2, TASK_2_2));
queue.offerAll(clientStates.keySet());
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_2));
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_3));
}
@Test
public void shouldNotRetainDuplicates() {
clientStates = getClientStatesMap(client1);
queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
queue.offerAll(clientStates.keySet());
queue.offer(UUID_1);
assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
assertNull(queue.poll(DUMMY_TASK));
}
@Test
public void shouldOnlyReturnValidClients() {
clientStates = getClientStatesMap(client1, client2);
queue = new ValidClientsByTaskLoadQueue(clientStates, (client, task) -> client.equals(UUID_1));
queue.offerAll(clientStates.keySet());
assertThat(queue.poll(DUMMY_TASK, 2), equalTo(singletonList(UUID_1)));
}
@Test
public void shouldReturnUpToNumClients() {
clientStates = getClientStatesMap(client1, client2, client3);
queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
client1.assignActive(TASK_0_0);
client2.assignActiveTasks(asList(TASK_0_1, TASK_1_1));
client3.assignActiveTasks(asList(TASK_0_2, TASK_1_2, TASK_2_2));
queue.offerAll(clientStates.keySet());
assertThat(queue.poll(DUMMY_TASK, 2), equalTo(asList(UUID_1, UUID_2)));
}
}
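The behavior these (now removed) tests captured, polling the least-loaded client that the validity predicate accepts for a given task without retaining duplicates, lives on in ConstrainedPrioritySet. A rough single-poll equivalent over a snapshot of client states, assuming assignedTaskCount() as the load measure:

    import java.util.Comparator;
    import java.util.Map;
    import java.util.UUID;
    import java.util.function.BiFunction;

    // Sketch only: returns the least-loaded valid client for the task, or null
    // if no client passes the predicate (mirroring shouldReturnNull above).
    static UUID pollLeastLoadedValidClient(final Map<UUID, ClientState> clientStates,
                                           final BiFunction<UUID, TaskId, Boolean> validity,
                                           final TaskId task) {
        return clientStates.entrySet().stream()
            .filter(entry -> validity.apply(entry.getKey(), task))
            .min(Comparator.comparingInt(
                     (Map.Entry<UUID, ClientState> entry) -> entry.getValue().assignedTaskCount())
                 .thenComparing(Map.Entry::getKey)) // deterministic tie-break on client id
            .map(Map.Entry::getKey)
            .orElse(null);
    }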