Browse Source

KAFKA-15100; KRaft data race with the expiration service (#14141)

The KRaft client uses an expiration service to complete FETCH requests that have timed out. This expiration service uses a different thread from the KRaft polling thread. This means that it is unsafe for the expiration service thread to call tryCompleteFetchRequest. tryCompleteFetchRequest reads and updates a lot of states that is assumed to be only be read and updated from the polling thread.

The KRaft client now does not call tryCompleteFetchRequest when the FETCH request has expired. It instead will send the FETCH response that was computed when the FETCH request was first handled.

This change also fixes a bug where the KRaft client was not sending the FETCH response immediately, if the response contained a diverging epoch or snapshot id.

Reviewers: Jason Gustafson <jason@confluent.io>
pull/14177/head
José Armando García Sancio 1 year ago committed by GitHub
parent
commit
dafe51b658
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      raft/src/main/java/org/apache/kafka/raft/ElectionState.java
  2. 49
      raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
  3. 5
      raft/src/main/java/org/apache/kafka/raft/LeaderState.java
  4. 3
      raft/src/main/java/org/apache/kafka/raft/QuorumState.java
  5. 32
      raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
  6. 30
      raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
  7. 54
      raft/src/test/java/org/apache/kafka/raft/MockLog.java
  8. 8
      raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
  9. 2
      raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java

6
raft/src/main/java/org/apache/kafka/raft/ElectionState.java

@ -61,7 +61,7 @@ public class ElectionState { @@ -61,7 +61,7 @@ public class ElectionState {
public boolean isLeader(int nodeId) {
if (nodeId < 0)
throw new IllegalArgumentException("Invalid negative nodeId: " + nodeId);
return leaderIdOpt.orElse(-1) == nodeId;
return leaderIdOrSentinel() == nodeId;
}
public boolean isVotedCandidate(int nodeId) {
@ -94,6 +94,10 @@ public class ElectionState { @@ -94,6 +94,10 @@ public class ElectionState {
return votedIdOpt.isPresent();
}
public int leaderIdOrSentinel() {
return leaderIdOpt.orElse(-1);
}
@Override
public String toString() {

49
raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java

@ -879,9 +879,9 @@ public class KafkaRaftClient<T> implements RaftClient<T> { @@ -879,9 +879,9 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
.setRecords(records)
.setErrorCode(error.code())
.setLogStartOffset(log.startOffset())
.setHighWatermark(highWatermark
.map(offsetMetadata -> offsetMetadata.offset)
.orElse(-1L));
.setHighWatermark(
highWatermark.map(offsetMetadata -> offsetMetadata.offset).orElse(-1L)
);
partitionData.currentLeader()
.setLeaderEpoch(quorum.epoch())
@ -960,8 +960,9 @@ public class KafkaRaftClient<T> implements RaftClient<T> { @@ -960,8 +960,9 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
|| fetchPartition.fetchOffset() < 0
|| fetchPartition.lastFetchedEpoch() < 0
|| fetchPartition.lastFetchedEpoch() > fetchPartition.currentLeaderEpoch()) {
return completedFuture(buildEmptyFetchResponse(
Errors.INVALID_REQUEST, Optional.empty()));
return completedFuture(
buildEmptyFetchResponse(Errors.INVALID_REQUEST, Optional.empty())
);
}
int replicaId = FetchRequest.replicaId(request);
@ -971,7 +972,15 @@ public class KafkaRaftClient<T> implements RaftClient<T> { @@ -971,7 +972,15 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
if (partitionResponse.errorCode() != Errors.NONE.code()
|| FetchResponse.recordsSize(partitionResponse) > 0
|| request.maxWaitMs() == 0) {
|| request.maxWaitMs() == 0
|| isPartitionDiverged(partitionResponse)
|| isPartitionSnapshotted(partitionResponse)) {
// Reply immediately if any of the following is true
// 1. The response contains an errror
// 2. There are records in the response
// 3. The fetching replica doesn't want to wait for the partition to contain new data
// 4. The fetching replica needs to truncate because the log diverged
// 5. The fetching replica needs to fetch a snapshot
return completedFuture(response);
}
@ -984,11 +993,16 @@ public class KafkaRaftClient<T> implements RaftClient<T> { @@ -984,11 +993,16 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
Throwable cause = exception instanceof ExecutionException ?
exception.getCause() : exception;
// If the fetch timed out in purgatory, it means no new data is available,
// and we will complete the fetch successfully. Otherwise, if there was
// any other error, we need to return it.
Errors error = Errors.forException(cause);
if (error != Errors.REQUEST_TIMED_OUT) {
if (error == Errors.REQUEST_TIMED_OUT) {
// Note that for this case the calling thread is the expiration service thread and not the
// polling thread.
//
// If the fetch request timed out in purgatory, it means no new data is available,
// just return the original fetch response.
return response;
} else {
// If there was any error other than REQUEST_TIMED_OUT, return it.
logger.info("Failed to handle fetch from {} at {} due to {}",
replicaId, fetchPartition.fetchOffset(), error);
return buildEmptyFetchResponse(error, Optional.empty());
@ -999,6 +1013,9 @@ public class KafkaRaftClient<T> implements RaftClient<T> { @@ -999,6 +1013,9 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
logger.trace("Completing delayed fetch from {} starting at offset {} at {}",
replicaId, fetchPartition.fetchOffset(), completionTimeMs);
// It is safe to call tryCompleteFetchRequest because only the polling thread completes this
// future successfully. This is true because only the polling thread appends record batches to
// the log from maybeAppendBatches.
return tryCompleteFetchRequest(replicaId, fetchPartition, time.milliseconds());
});
}
@ -1048,6 +1065,18 @@ public class KafkaRaftClient<T> implements RaftClient<T> { @@ -1048,6 +1065,18 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
}
}
private static boolean isPartitionDiverged(FetchResponseData.PartitionData partitionResponseData) {
FetchResponseData.EpochEndOffset divergingEpoch = partitionResponseData.divergingEpoch();
return divergingEpoch.epoch() != -1 || divergingEpoch.endOffset() != -1;
}
private static boolean isPartitionSnapshotted(FetchResponseData.PartitionData partitionResponseData) {
FetchResponseData.SnapshotId snapshotId = partitionResponseData.snapshotId();
return snapshotId.epoch() != -1 || snapshotId.endOffset() != -1;
}
private static OptionalInt optionalLeaderId(int leaderIdOrNil) {
if (leaderIdOrNil < 0)
return OptionalInt.empty();

5
raft/src/main/java/org/apache/kafka/raft/LeaderState.java

@ -50,7 +50,7 @@ public class LeaderState<T> implements EpochState { @@ -50,7 +50,7 @@ public class LeaderState<T> implements EpochState {
private final long epochStartOffset;
private final Set<Integer> grantingVoters;
private Optional<LogOffsetMetadata> highWatermark;
private Optional<LogOffsetMetadata> highWatermark = Optional.empty();
private final Map<Integer, ReplicaState> voterStates = new HashMap<>();
private final Map<Integer, ReplicaState> observerStates = new HashMap<>();
private final Logger log;
@ -71,7 +71,6 @@ public class LeaderState<T> implements EpochState { @@ -71,7 +71,6 @@ public class LeaderState<T> implements EpochState {
this.localId = localId;
this.epoch = epoch;
this.epochStartOffset = epochStartOffset;
this.highWatermark = Optional.empty();
for (int voterId : voters) {
boolean hasAcknowledgedLeader = voterId == localId;
@ -337,7 +336,7 @@ public class LeaderState<T> implements EpochState { @@ -337,7 +336,7 @@ public class LeaderState<T> implements EpochState {
.setErrorCode(Errors.NONE.code())
.setLeaderId(localId)
.setLeaderEpoch(epoch)
.setHighWatermark(highWatermark().map(offsetMetadata -> offsetMetadata.offset).orElse(-1L))
.setHighWatermark(highWatermark.map(offsetMetadata -> offsetMetadata.offset).orElse(-1L))
.setCurrentVoters(describeReplicaStates(voterStates, currentTimeMs))
.setObservers(describeReplicaStates(observerStates, currentTimeMs));
}

3
raft/src/main/java/org/apache/kafka/raft/QuorumState.java

@ -233,7 +233,7 @@ public class QuorumState { @@ -233,7 +233,7 @@ public class QuorumState {
}
public int leaderIdOrSentinel() {
return leaderId().orElse(-1);
return state.election().leaderIdOrSentinel();
}
public Optional<LogOffsetMetadata> highWatermark() {
@ -570,5 +570,4 @@ public class QuorumState { @@ -570,5 +570,4 @@ public class QuorumState {
public boolean isCandidate() {
return state instanceof CandidateState;
}
}

32
raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java

@ -268,6 +268,38 @@ final public class KafkaRaftClientSnapshotTest { @@ -268,6 +268,38 @@ final public class KafkaRaftClientSnapshotTest {
}
}
@Test
public void testLeaderImmediatelySendsSnapshotId() throws Exception {
int localId = 0;
int otherNodeId = 1;
Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
OffsetAndEpoch snapshotId = new OffsetAndEpoch(3, 4);
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withUnknownLeader(snapshotId.epoch())
.appendToLog(snapshotId.epoch(), Arrays.asList("a", "b", "c"))
.appendToLog(snapshotId.epoch(), Arrays.asList("d", "e", "f"))
.appendToLog(snapshotId.epoch(), Arrays.asList("g", "h", "i"))
.withEmptySnapshot(snapshotId)
.deleteBeforeSnapshot(snapshotId)
.build();
context.becomeLeader();
int epoch = context.currentEpoch();
// Send a fetch request for an end offset and epoch which has been snapshotted
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 6, 2, 500));
context.client.poll();
// Expect that the leader replies immediately with a snapshot id
FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse();
assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
assertEquals(localId, partitionResponse.currentLeader().leaderId());
assertEquals(snapshotId.epoch(), partitionResponse.snapshotId().epoch());
assertEquals(snapshotId.offset(), partitionResponse.snapshotId().endOffset());
}
@Test
public void testFetchRequestOffsetLessThanLogStart() throws Exception {
int localId = 0;

30
raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java

@ -1206,6 +1206,36 @@ public class KafkaRaftClientTest { @@ -1206,6 +1206,36 @@ public class KafkaRaftClientTest {
assertEquals(records, context.listener.commitWithLastOffset(offset));
}
@Test
public void testLeaderImmediatelySendsDivergingEpoch() throws Exception {
int localId = 0;
int otherNodeId = 1;
Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withUnknownLeader(5)
.appendToLog(1, Arrays.asList("a", "b", "c"))
.appendToLog(3, Arrays.asList("d", "e", "f"))
.appendToLog(5, Arrays.asList("g", "h", "i"))
.build();
// Start off as the leader
context.becomeLeader();
int epoch = context.currentEpoch();
// Send a fetch request for an end offset and epoch which has diverged
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 6, 2, 500));
context.client.poll();
// Expect that the leader replies immediately with a diverging epoch
FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse();
assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
assertEquals(localId, partitionResponse.currentLeader().leaderId());
assertEquals(1, partitionResponse.divergingEpoch().epoch());
assertEquals(3, partitionResponse.divergingEpoch().endOffset());
}
@Test
public void testCandidateIgnoreVoteRequestOnSameEpoch() throws Exception {
int localId = 0;

54
raft/src/test/java/org/apache/kafka/raft/MockLog.java

@ -85,6 +85,7 @@ public class MockLog implements ReplicatedLog { @@ -85,6 +85,7 @@ public class MockLog implements ReplicatedLog {
" which is below the current high watermark " + highWatermark);
}
logger.debug("Truncating log to end offset {}", offset);
batches.removeIf(entry -> entry.lastOffset() >= offset);
epochStartOffsets.removeIf(epochStartOffset -> epochStartOffset.startOffset >= offset);
firstUnflushedOffset = Math.min(firstUnflushedOffset, endOffset().offset);
@ -98,6 +99,8 @@ public class MockLog implements ReplicatedLog { @@ -98,6 +99,8 @@ public class MockLog implements ReplicatedLog {
(snapshotId.epoch() == logLastFetchedEpoch().orElse(0) &&
snapshotId.offset() > endOffset().offset)) {
logger.debug("Truncating to the latest snapshot at {}", snapshotId);
batches.clear();
epochStartOffsets.clear();
snapshots.headMap(snapshotId, false).clear();
@ -278,10 +281,11 @@ public class MockLog implements ReplicatedLog { @@ -278,10 +281,11 @@ public class MockLog implements ReplicatedLog {
return append(records, OptionalInt.of(epoch));
}
private Long appendBatch(LogBatch batch) {
private long appendBatch(LogBatch batch) {
if (batch.epoch > lastFetchedEpoch()) {
epochStartOffsets.add(new EpochStartOffset(batch.epoch, batch.firstOffset()));
}
batches.add(batch);
return batch.firstOffset();
}
@ -311,15 +315,22 @@ public class MockLog implements ReplicatedLog { @@ -311,15 +315,22 @@ public class MockLog implements ReplicatedLog {
);
}
List<LogEntry> entries = buildEntries(batch, Record::offset);
appendBatch(
new LogBatch(
epoch.orElseGet(batch::partitionLeaderEpoch),
batch.isControlBatch(),
entries
)
LogBatch logBatch = new LogBatch(
epoch.orElseGet(batch::partitionLeaderEpoch),
batch.isControlBatch(),
buildEntries(batch, Record::offset)
);
lastOffset = entries.get(entries.size() - 1).offset;
if (logger.isDebugEnabled()) {
String nodeState = "Follower";
if (epoch.isPresent()) {
nodeState = "Leader";
}
logger.debug("{} appending to the log {}", nodeState, logBatch);
}
appendBatch(logBatch);
lastOffset = logBatch.last().offset;
}
return new LogAppendInfo(baseOffset, lastOffset);
@ -385,13 +396,9 @@ public class MockLog implements ReplicatedLog { @@ -385,13 +396,9 @@ public class MockLog implements ReplicatedLog {
@Override
public LogFetchInfo read(long startOffset, Isolation isolation) {
OptionalLong maxOffsetOpt = isolation == Isolation.COMMITTED ?
OptionalLong.of(highWatermark.offset) :
OptionalLong.empty();
verifyOffsetInRange(startOffset);
long maxOffset = maxOffsetOpt.orElse(endOffset().offset);
long maxOffset = isolation == Isolation.COMMITTED ? highWatermark.offset : endOffset().offset;
if (startOffset >= maxOffset) {
return new LogFetchInfo(MemoryRecords.EMPTY, new LogOffsetMetadata(
startOffset, metadataForOffset(startOffset)));
@ -401,6 +408,13 @@ public class MockLog implements ReplicatedLog { @@ -401,6 +408,13 @@ public class MockLog implements ReplicatedLog {
int batchCount = 0;
LogOffsetMetadata batchStartOffset = null;
logger.debug(
"Looking for a batch that starts at {} and ends at {} for isolation {}",
startOffset,
maxOffset,
isolation
);
for (LogBatch batch : batches) {
// Note that start offset is inclusive while max offset is exclusive. We only return
// complete batches, so batches which end at an offset larger than the max offset are
@ -541,6 +555,7 @@ public class MockLog implements ReplicatedLog { @@ -541,6 +555,7 @@ public class MockLog implements ReplicatedLog {
if (snapshots.containsKey(snapshotId)) {
snapshots.headMap(snapshotId, false).clear();
logger.debug("Deleting batches included in the snapshot {}", snapshotId);
batches.removeIf(entry -> entry.lastOffset() < snapshotId.offset());
AtomicReference<Optional<EpochStartOffset>> last = new AtomicReference<>(Optional.empty());
@ -566,6 +581,17 @@ public class MockLog implements ReplicatedLog { @@ -566,6 +581,17 @@ public class MockLog implements ReplicatedLog {
return updated;
}
@Override
public String toString() {
return String.format(
"MockLog(epochStartOffsets=%s, batches=%s, snapshots=%s, highWatermark=%s",
epochStartOffsets,
batches,
snapshots,
highWatermark
);
}
static class MockOffsetMetadata implements OffsetMetadata {
final long id;

8
raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java

@ -385,9 +385,7 @@ public final class RaftClientTestContext { @@ -385,9 +385,7 @@ public final class RaftClientTestContext {
return new LeaderAndEpoch(election.leaderIdOpt, election.epoch);
}
void expectAndGrantVotes(
int epoch
) throws Exception {
void expectAndGrantVotes(int epoch) throws Exception {
pollUntilRequest();
List<RaftRequest.Outbound> voteRequests = collectVoteRequests(epoch,
@ -406,9 +404,7 @@ public final class RaftClientTestContext { @@ -406,9 +404,7 @@ public final class RaftClientTestContext {
return localId.orElseThrow(() -> new AssertionError("Required local id is not defined"));
}
private void expectBeginEpoch(
int epoch
) throws Exception {
private void expectBeginEpoch(int epoch) throws Exception {
pollUntilRequest();
for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) {
BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localIdOrThrow());

2
raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java

@ -38,8 +38,8 @@ import org.apache.kafka.raft.MockLog.LogBatch; @@ -38,8 +38,8 @@ import org.apache.kafka.raft.MockLog.LogBatch;
import org.apache.kafka.raft.MockLog.LogEntry;
import org.apache.kafka.raft.internals.BatchMemoryPool;
import org.apache.kafka.server.common.serialization.RecordSerde;
import org.apache.kafka.snapshot.SnapshotReader;
import org.apache.kafka.snapshot.RecordsSnapshotReader;
import org.apache.kafka.snapshot.SnapshotReader;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;

Loading…
Cancel
Save