Browse Source

MINOR: Rewrite/Move KafkaNetworkChannel to the `raft` module (#14559)

This is now possible since `InterBrokerSend` was moved from `core` to `server-common`.
Also rewrite/move `KafkaNetworkChannelTest`.

The scala version of `KafkaNetworkChannelTest` passed with the changes here (before I
deleted it).

Reviewers: Justine Olshan <jolshan@confluent.io>, José Armando García Sancio <jsancio@users.noreply.github.com>
pull/14569/head
Ismael Juma 11 months ago committed by GitHub
parent
commit
69e591db3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      checkstyle/import-control.xml
  2. 191
      core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
  3. 2
      core/src/main/scala/kafka/raft/RaftManager.scala
  4. 316
      core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
  5. 183
      raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
  6. 6
      raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
  7. 323
      raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java

1
checkstyle/import-control.xml

@ -406,6 +406,7 @@
<allow pkg="org.apache.kafka.common.protocol" /> <allow pkg="org.apache.kafka.common.protocol" />
<allow pkg="org.apache.kafka.server.common" /> <allow pkg="org.apache.kafka.server.common" />
<allow pkg="org.apache.kafka.server.common.serialization" /> <allow pkg="org.apache.kafka.server.common.serialization" />
<allow pkg="org.apache.kafka.server.util" />
<allow pkg="org.apache.kafka.test"/> <allow pkg="org.apache.kafka.test"/>
<allow pkg="com.fasterxml.jackson" /> <allow pkg="com.fasterxml.jackson" />
<allow pkg="net.jqwik"/> <allow pkg="net.jqwik"/>

191
core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala

@ -1,191 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.raft
import kafka.utils.Logging
import org.apache.kafka.clients.{ClientResponse, KafkaClient}
import org.apache.kafka.common.Node
import org.apache.kafka.common.message._
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
import org.apache.kafka.common.requests._
import org.apache.kafka.common.utils.Time
import org.apache.kafka.raft.RaftConfig.InetAddressSpec
import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil}
import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler}
import java.util
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable
object KafkaNetworkChannel {
private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = {
requestData match {
case voteRequest: VoteRequestData =>
new VoteRequest.Builder(voteRequest)
case beginEpochRequest: BeginQuorumEpochRequestData =>
new BeginQuorumEpochRequest.Builder(beginEpochRequest)
case endEpochRequest: EndQuorumEpochRequestData =>
new EndQuorumEpochRequest.Builder(endEpochRequest)
case fetchRequest: FetchRequestData =>
new FetchRequest.SimpleBuilder(fetchRequest)
case fetchSnapshotRequest: FetchSnapshotRequestData =>
new FetchSnapshotRequest.Builder(fetchSnapshotRequest)
case _ =>
throw new IllegalArgumentException(s"Unexpected type for requestData: $requestData")
}
}
}
private[raft] class RaftSendThread(
name: String,
networkClient: KafkaClient,
requestTimeoutMs: Int,
time: Time,
isInterruptible: Boolean = true
) extends InterBrokerSendThread(
name,
networkClient,
requestTimeoutMs,
time,
isInterruptible
) {
private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]()
def generateRequests(): util.Collection[RequestAndCompletionHandler] = {
val list = new util.ArrayList[RequestAndCompletionHandler]()
while (true) {
val request = queue.poll()
if (request == null) {
return list
} else {
list.add(request)
}
}
list
}
def sendRequest(request: RequestAndCompletionHandler): Unit = {
queue.add(request)
wakeup()
}
}
class KafkaNetworkChannel(
time: Time,
client: KafkaClient,
requestTimeoutMs: Int,
threadNamePrefix: String
) extends NetworkChannel with Logging {
import KafkaNetworkChannel._
type ResponseHandler = AbstractResponse => Unit
private val correlationIdCounter = new AtomicInteger(0)
private val endpoints = mutable.HashMap.empty[Int, Node]
private val requestThread = new RaftSendThread(
name = threadNamePrefix + "-outbound-request-thread",
networkClient = client,
requestTimeoutMs = requestTimeoutMs,
time = time,
isInterruptible = false
)
override def send(request: RaftRequest.Outbound): Unit = {
def completeFuture(message: ApiMessage): Unit = {
val response = new RaftResponse.Inbound(
request.correlationId,
message,
request.destinationId
)
request.completion.complete(response)
}
def onComplete(clientResponse: ClientResponse): Unit = {
val response = if (clientResponse.versionMismatch != null) {
error(s"Request $request failed due to unsupported version error",
clientResponse.versionMismatch)
errorResponse(request.data, Errors.UNSUPPORTED_VERSION)
} else if (clientResponse.authenticationException != null) {
// For now we treat authentication errors as retriable. We use the
// `NETWORK_EXCEPTION` error code for lack of a good alternative.
// Note that `NodeToControllerChannelManager` will still log the
// authentication errors so that users have a chance to fix the problem.
error(s"Request $request failed due to authentication error",
clientResponse.authenticationException)
errorResponse(request.data, Errors.NETWORK_EXCEPTION)
} else if (clientResponse.wasDisconnected()) {
errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)
} else {
clientResponse.responseBody.data
}
completeFuture(response)
}
endpoints.get(request.destinationId) match {
case Some(node) =>
requestThread.sendRequest(new RequestAndCompletionHandler(
request.createdTimeMs,
node,
buildRequest(request.data),
onComplete
))
case None =>
completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE))
}
}
// Visible for testing
private[raft] def pollOnce(): Unit = {
requestThread.doWork()
}
override def newCorrelationId(): Int = {
correlationIdCounter.getAndIncrement()
}
private def errorResponse(
request: ApiMessage,
error: Errors
): ApiMessage = {
val apiKey = ApiKeys.forId(request.apiKey)
RaftUtil.errorResponse(apiKey, error)
}
override def updateEndpoint(id: Int, spec: InetAddressSpec): Unit = {
val node = new Node(id, spec.address.getHostString, spec.address.getPort)
endpoints.put(id, node)
}
def start(): Unit = {
requestThread.start()
}
def initiateShutdown(): Unit = {
requestThread.initiateShutdown()
}
override def close(): Unit = {
requestThread.shutdown()
}
}

2
core/src/main/scala/kafka/raft/RaftManager.scala

@ -42,7 +42,7 @@ import org.apache.kafka.common.security.JaasContext
import org.apache.kafka.common.security.auth.SecurityProtocol import org.apache.kafka.common.security.auth.SecurityProtocol
import org.apache.kafka.common.utils.{LogContext, Time} import org.apache.kafka.common.utils.{LogContext, Time}
import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec, NON_ROUTABLE_ADDRESS, UnknownAddressSpec} import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec, NON_ROUTABLE_ADDRESS, UnknownAddressSpec}
import org.apache.kafka.raft.{FileBasedStateStore, KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog} import org.apache.kafka.raft.{FileBasedStateStore, KafkaNetworkChannel, KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog}
import org.apache.kafka.server.common.serialization.RecordSerde import org.apache.kafka.server.common.serialization.RecordSerde
import org.apache.kafka.server.util.{KafkaScheduler, ShutdownableThread} import org.apache.kafka.server.util.{KafkaScheduler, ShutdownableThread}
import org.apache.kafka.server.fault.FaultHandler import org.apache.kafka.server.fault.FaultHandler

316
core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala

@ -1,316 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.raft
import java.net.InetSocketAddress
import java.util
import java.util.Collections
import org.apache.kafka.clients.MockClient.MockMetadataUpdater
import org.apache.kafka.clients.{MockClient, NodeApiVersions}
import org.apache.kafka.common.message.FetchRequestData.ReplicaState
import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData}
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
import org.apache.kafka.common.requests.{AbstractResponse, ApiVersionsResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchRequest, FetchResponse, VoteRequest, VoteResponse}
import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource
import org.apache.kafka.common.utils.{MockTime, Time}
import org.apache.kafka.common.{Node, TopicPartition, Uuid}
import org.apache.kafka.raft.RaftConfig.InetAddressSpec
import org.apache.kafka.raft.{RaftRequest, RaftUtil}
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.{BeforeEach, Test}
import org.junit.jupiter.params.ParameterizedTest
import scala.jdk.CollectionConverters._
class KafkaNetworkChannelTest {
import KafkaNetworkChannelTest._
private val clusterId = "clusterId"
private val requestTimeoutMs = 30000
private val time = new MockTime()
private val client = new MockClient(time, new StubMetadataUpdater)
private val topicPartition = new TopicPartition("topic", 0)
private val topicId = Uuid.randomUuid()
private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, threadNamePrefix = "test-raft")
@BeforeEach
def setupSupportedApis(): Unit = {
val supportedApis = RaftApis.map(ApiVersionsResponse.toApiVersion)
client.setNodeApiVersions(NodeApiVersions.create(supportedApis.asJava))
}
@Test
def testSendToUnknownDestination(): Unit = {
val destinationId = 2
assertBrokerNotAvailable(destinationId)
}
@Test
def testSendToBlackedOutDestination(): Unit = {
val destinationId = 2
val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
channel.updateEndpoint(destinationId, new InetAddressSpec(
new InetSocketAddress(destinationNode.host, destinationNode.port)))
client.backoff(destinationNode, 500)
assertBrokerNotAvailable(destinationId)
}
@Test
def testWakeupClientOnSend(): Unit = {
val destinationId = 2
val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
channel.updateEndpoint(destinationId, new InetAddressSpec(
new InetSocketAddress(destinationNode.host, destinationNode.port)))
client.enableBlockingUntilWakeup(1)
val ioThread = new Thread() {
override def run(): Unit = {
// Block in poll until we get the expected wakeup
channel.pollOnce()
// Poll a second time to send request and receive response
channel.pollOnce()
}
}
val response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST))
client.prepareResponseFrom(response, destinationNode, false)
ioThread.start()
val request = sendTestRequest(ApiKeys.FETCH, destinationId)
ioThread.join()
assertResponseCompleted(request, Errors.INVALID_REQUEST)
}
@Test
def testSendAndDisconnect(): Unit = {
val destinationId = 2
val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
channel.updateEndpoint(destinationId, new InetAddressSpec(
new InetSocketAddress(destinationNode.host, destinationNode.port)))
for (apiKey <- RaftApis) {
val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST))
client.prepareResponseFrom(response, destinationNode, true)
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE)
}
}
@Test
def testSendAndFailAuthentication(): Unit = {
val destinationId = 2
val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
channel.updateEndpoint(destinationId, new InetAddressSpec(
new InetSocketAddress(destinationNode.host, destinationNode.port)))
for (apiKey <- RaftApis) {
client.createPendingAuthenticationError(destinationNode, 100)
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION)
// reset to clear backoff time
client.reset()
}
}
private def assertBrokerNotAvailable(destinationId: Int): Unit = {
for (apiKey <- RaftApis) {
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE)
}
}
@Test
def testSendAndReceiveOutboundRequest(): Unit = {
val destinationId = 2
val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
channel.updateEndpoint(destinationId, new InetAddressSpec(
new InetSocketAddress(destinationNode.host, destinationNode.port)))
for (apiKey <- RaftApis) {
val expectedError = Errors.INVALID_REQUEST
val response = buildResponse(buildTestErrorResponse(apiKey, expectedError))
client.prepareResponseFrom(response, destinationNode)
sendAndAssertErrorResponse(apiKey, destinationId, expectedError)
}
}
@Test
def testUnsupportedVersionError(): Unit = {
val destinationId = 2
val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
channel.updateEndpoint(destinationId, new InetAddressSpec(
new InetSocketAddress(destinationNode.host, destinationNode.port)))
for (apiKey <- RaftApis) {
client.prepareUnsupportedVersionResponse(request => request.apiKey == apiKey)
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION)
}
}
@ParameterizedTest
@ApiKeyVersionsSource(apiKey = ApiKeys.FETCH)
def testFetchRequestDowngrade(version: Short): Unit = {
val destinationId = 2
val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
channel.updateEndpoint(destinationId, new InetAddressSpec(
new InetSocketAddress(destinationNode.host, destinationNode.port)))
sendTestRequest(ApiKeys.FETCH, destinationId)
channel.pollOnce()
assertEquals(1, client.requests.size)
val request = client.requests.peek.requestBuilder.build(version)
if (version < 15) {
assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == 1)
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == -1)
} else {
assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == -1)
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == 1)
}
}
private def sendTestRequest(
apiKey: ApiKeys,
destinationId: Int,
): RaftRequest.Outbound = {
val correlationId = channel.newCorrelationId()
val createdTimeMs = time.milliseconds()
val apiRequest = buildTestRequest(apiKey)
val request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs)
channel.send(request)
request
}
private def assertResponseCompleted(
request: RaftRequest.Outbound,
expectedError: Errors
): Unit = {
assertTrue(request.completion.isDone)
val response = request.completion.get()
assertEquals(request.destinationId, response.sourceId)
assertEquals(request.correlationId, response.correlationId)
assertEquals(request.data.apiKey, response.data.apiKey)
assertEquals(expectedError, extractError(response.data))
}
private def sendAndAssertErrorResponse(
apiKey: ApiKeys,
destinationId: Int,
error: Errors
): Unit = {
val request = sendTestRequest(apiKey, destinationId)
channel.pollOnce()
assertResponseCompleted(request, error)
}
private def buildTestRequest(key: ApiKeys): ApiMessage = {
val leaderEpoch = 5
val leaderId = 1
key match {
case ApiKeys.BEGIN_QUORUM_EPOCH =>
BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId)
case ApiKeys.END_QUORUM_EPOCH =>
EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId,
leaderEpoch, Collections.singletonList(2))
case ApiKeys.VOTE =>
val lastEpoch = 4
VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329)
case ApiKeys.FETCH =>
val request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition => {
fetchPartition
.setCurrentLeaderEpoch(5)
.setFetchOffset(333)
.setLastFetchedEpoch(5)
})
request.setReplicaState(new ReplicaState().setReplicaId(1))
case _ =>
throw new AssertionError(s"Unexpected api $key")
}
}
private def buildTestErrorResponse(key: ApiKeys, error: Errors): ApiMessage = {
key match {
case ApiKeys.BEGIN_QUORUM_EPOCH =>
new BeginQuorumEpochResponseData()
.setErrorCode(error.code)
case ApiKeys.END_QUORUM_EPOCH =>
new EndQuorumEpochResponseData()
.setErrorCode(error.code)
case ApiKeys.VOTE =>
VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
case ApiKeys.FETCH =>
new FetchResponseData()
.setErrorCode(error.code)
case _ =>
throw new AssertionError(s"Unexpected api $key")
}
}
private def extractError(response: ApiMessage): Errors = {
val code = (response: @unchecked) match {
case res: BeginQuorumEpochResponseData => res.errorCode
case res: EndQuorumEpochResponseData => res.errorCode
case res: FetchResponseData => res.errorCode
case res: VoteResponseData => res.errorCode
}
Errors.forCode(code)
}
def buildResponse(responseData: ApiMessage): AbstractResponse = {
responseData match {
case voteResponse: VoteResponseData =>
new VoteResponse(voteResponse)
case beginEpochResponse: BeginQuorumEpochResponseData =>
new BeginQuorumEpochResponse(beginEpochResponse)
case endEpochResponse: EndQuorumEpochResponseData =>
new EndQuorumEpochResponse(endEpochResponse)
case fetchResponse: FetchResponseData =>
new FetchResponse(fetchResponse)
case _ =>
throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData")
}
}
}
object KafkaNetworkChannelTest {
val RaftApis = Seq(
ApiKeys.VOTE,
ApiKeys.BEGIN_QUORUM_EPOCH,
ApiKeys.END_QUORUM_EPOCH,
ApiKeys.FETCH,
)
private class StubMetadataUpdater extends MockMetadataUpdater {
override def fetchNodes(): util.List[Node] = Collections.emptyList()
override def isUpdateNeeded: Boolean = false
override def update(time: Time, update: MockClient.MetadataUpdate): Unit = {}
}
}

183
raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java

@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.raft;
import org.apache.kafka.clients.ClientResponse;
import org.apache.kafka.clients.KafkaClient;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.message.BeginQuorumEpochRequestData;
import org.apache.kafka.common.message.EndQuorumEpochRequestData;
import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchSnapshotRequestData;
import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.AbstractRequest;
import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
import org.apache.kafka.common.requests.EndQuorumEpochRequest;
import org.apache.kafka.common.requests.FetchRequest;
import org.apache.kafka.common.requests.FetchSnapshotRequest;
import org.apache.kafka.common.requests.VoteRequest;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.server.util.InterBrokerSendThread;
import org.apache.kafka.server.util.RequestAndCompletionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
public class KafkaNetworkChannel implements NetworkChannel {
static class SendThread extends InterBrokerSendThread {
private Queue<RequestAndCompletionHandler> queue = new ConcurrentLinkedQueue<>();
public SendThread(String name, KafkaClient networkClient, int requestTimeoutMs, Time time, boolean isInterruptible) {
super(name, networkClient, requestTimeoutMs, time, isInterruptible);
}
@Override
public Collection<RequestAndCompletionHandler> generateRequests() {
List<RequestAndCompletionHandler> list = new ArrayList<>();
while (true) {
RequestAndCompletionHandler request = queue.poll();
if (request == null) {
return list;
} else {
list.add(request);
}
}
}
public void sendRequest(RequestAndCompletionHandler request) {
queue.add(request);
wakeup();
}
}
private static final Logger log = LoggerFactory.getLogger(KafkaNetworkChannel.class);
private final SendThread requestThread;
private final AtomicInteger correlationIdCounter = new AtomicInteger(0);
private final Map<Integer, Node> endpoints = new HashMap<>();
public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) {
this.requestThread = new SendThread(
threadNamePrefix + "-outbound-request-thread",
client,
requestTimeoutMs,
time,
false
);
}
@Override
public int newCorrelationId() {
return correlationIdCounter.getAndIncrement();
}
@Override
public void send(RaftRequest.Outbound request) {
Node node = endpoints.get(request.destinationId());
if (node != null) {
requestThread.sendRequest(new RequestAndCompletionHandler(
request.createdTimeMs,
node,
buildRequest(request.data),
response -> sendOnComplete(request, response)
));
} else
sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE));
}
private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) {
RaftResponse.Inbound response = new RaftResponse.Inbound(
request.correlationId,
message,
request.destinationId()
);
request.completion.complete(response);
}
private void sendOnComplete(RaftRequest.Outbound request, ClientResponse clientResponse) {
ApiMessage response;
if (clientResponse.versionMismatch() != null) {
log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch());
response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION);
} else if (clientResponse.authenticationException() != null) {
// For now we treat authentication errors as retriable. We use the
// `NETWORK_EXCEPTION` error code for lack of a good alternative.
// Note that `NodeToControllerChannelManager` will still log the
// authentication errors so that users have a chance to fix the problem.
log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException());
response = errorResponse(request.data, Errors.NETWORK_EXCEPTION);
} else if (clientResponse.wasDisconnected()) {
response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE);
} else {
response = clientResponse.responseBody().data();
}
sendCompleteFuture(request, response);
}
private ApiMessage errorResponse(ApiMessage request, Errors error) {
ApiKeys apiKey = ApiKeys.forId(request.apiKey());
return RaftUtil.errorResponse(apiKey, error);
}
@Override
public void updateEndpoint(int id, RaftConfig.InetAddressSpec spec) {
Node node = new Node(id, spec.address.getHostString(), spec.address.getPort());
endpoints.put(id, node);
}
public void start() {
requestThread.start();
}
@Override
public void close() throws InterruptedException {
requestThread.shutdown();
}
// Visible for testing
public void pollOnce() {
requestThread.doWork();
}
static AbstractRequest.Builder<? extends AbstractRequest> buildRequest(ApiMessage requestData) {
if (requestData instanceof VoteRequestData)
return new VoteRequest.Builder((VoteRequestData) requestData);
if (requestData instanceof BeginQuorumEpochRequestData)
return new BeginQuorumEpochRequest.Builder((BeginQuorumEpochRequestData) requestData);
if (requestData instanceof EndQuorumEpochRequestData)
return new EndQuorumEpochRequest.Builder((EndQuorumEpochRequestData) requestData);
if (requestData instanceof FetchRequestData)
return new FetchRequest.SimpleBuilder((FetchRequestData) requestData);
if (requestData instanceof FetchSnapshotRequestData)
return new FetchSnapshotRequest.Builder((FetchSnapshotRequestData) requestData);
throw new IllegalArgumentException("Unexpected type for requestData: " + requestData);
}
}

6
raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java

@ -16,13 +16,11 @@
*/ */
package org.apache.kafka.raft; package org.apache.kafka.raft;
import java.io.Closeable;
/** /**
* A simple network interface with few assumptions. We do not assume ordering * A simple network interface with few assumptions. We do not assume ordering
* of requests or even that every outbound request will receive a response. * of requests or even that every outbound request will receive a response.
*/ */
public interface NetworkChannel extends Closeable { public interface NetworkChannel extends AutoCloseable {
/** /**
* Generate a new and unique correlationId for a new request to be sent. * Generate a new and unique correlationId for a new request to be sent.
@ -41,6 +39,6 @@ public interface NetworkChannel extends Closeable {
*/ */
void updateEndpoint(int id, RaftConfig.InetAddressSpec address); void updateEndpoint(int id, RaftConfig.InetAddressSpec address);
default void close() {} default void close() throws InterruptedException {}
} }

323
raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java

@ -0,0 +1,323 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.raft;
import org.apache.kafka.clients.MockClient;
import org.apache.kafka.clients.NodeApiVersions;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.message.ApiVersionsResponseData;
import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
import org.apache.kafka.common.message.EndQuorumEpochResponseData;
import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.AbstractRequest;
import org.apache.kafka.common.requests.AbstractResponse;
import org.apache.kafka.common.requests.ApiVersionsResponse;
import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
import org.apache.kafka.common.requests.BeginQuorumEpochResponse;
import org.apache.kafka.common.requests.EndQuorumEpochRequest;
import org.apache.kafka.common.requests.EndQuorumEpochResponse;
import org.apache.kafka.common.requests.FetchRequest;
import org.apache.kafka.common.requests.FetchResponse;
import org.apache.kafka.common.requests.VoteRequest;
import org.apache.kafka.common.requests.VoteResponse;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class KafkaNetworkChannelTest {
private static class StubMetadataUpdater implements MockClient.MockMetadataUpdater {
@Override
public List<Node> fetchNodes() {
return Collections.emptyList();
}
@Override
public boolean isUpdateNeeded() {
return false;
}
@Override
public void update(Time time, MockClient.MetadataUpdate update) { }
}
private static final List<ApiKeys> RAFT_APIS = asList(
ApiKeys.VOTE,
ApiKeys.BEGIN_QUORUM_EPOCH,
ApiKeys.END_QUORUM_EPOCH,
ApiKeys.FETCH
);
private final String clusterId = "clusterId";
private final int requestTimeoutMs = 30000;
private final Time time = new MockTime();
private final MockClient client = new MockClient(time, new StubMetadataUpdater());
private final TopicPartition topicPartition = new TopicPartition("topic", 0);
private final Uuid topicId = Uuid.randomUuid();
private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft");
@BeforeEach
public void setupSupportedApis() {
List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS.stream().map(
ApiVersionsResponse::toApiVersion).collect(Collectors.toList());
client.setNodeApiVersions(NodeApiVersions.create(supportedApis));
}
@Test
public void testSendToUnknownDestination() throws ExecutionException, InterruptedException {
int destinationId = 2;
assertBrokerNotAvailable(destinationId);
}
@Test
public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
new InetSocketAddress(destinationNode.host(), destinationNode.port())));
client.backoff(destinationNode, 500);
assertBrokerNotAvailable(destinationId);
}
@Test
public void testWakeupClientOnSend() throws InterruptedException, ExecutionException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
new InetSocketAddress(destinationNode.host(), destinationNode.port())));
client.enableBlockingUntilWakeup(1);
Thread ioThread = new Thread(() -> {
// Block in poll until we get the expected wakeup
channel.pollOnce();
// Poll a second time to send request and receive response
channel.pollOnce();
});
AbstractResponse response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST));
client.prepareResponseFrom(response, destinationNode, false);
ioThread.start();
RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId);
ioThread.join();
assertResponseCompleted(request, Errors.INVALID_REQUEST);
}
@Test
public void testSendAndDisconnect() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
new InetSocketAddress(destinationNode.host(), destinationNode.port())));
for (ApiKeys apiKey : RAFT_APIS) {
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST));
client.prepareResponseFrom(response, destinationNode, true);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
}
}
@Test
public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
new InetSocketAddress(destinationNode.host(), destinationNode.port())));
for (ApiKeys apiKey : RAFT_APIS) {
client.createPendingAuthenticationError(destinationNode, 100);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION);
// reset to clear backoff time
client.reset();
}
}
private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException {
for (ApiKeys apiKey : RAFT_APIS) {
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
}
}
@Test
public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
new InetSocketAddress(destinationNode.host(), destinationNode.port())));
for (ApiKeys apiKey : RAFT_APIS) {
Errors expectedError = Errors.INVALID_REQUEST;
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError));
client.prepareResponseFrom(response, destinationNode);
System.out.println("api key " + apiKey + ", response " + response);
sendAndAssertErrorResponse(apiKey, destinationId, expectedError);
}
}
@Test
public void testUnsupportedVersionError() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
new InetSocketAddress(destinationNode.host(), destinationNode.port())));
for (ApiKeys apiKey : RAFT_APIS) {
client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION);
}
}
@ParameterizedTest
@ApiKeyVersionsSource(apiKey = ApiKeys.FETCH)
public void testFetchRequestDowngrade(short version) {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
new InetSocketAddress(destinationNode.host(), destinationNode.port())));
sendTestRequest(ApiKeys.FETCH, destinationId);
channel.pollOnce();
assertEquals(1, client.requests().size());
AbstractRequest request = client.requests().peek().requestBuilder().build(version);
if (version < 15) {
assertTrue(((FetchRequest) request).data().replicaId() == 1);
assertTrue(((FetchRequest) request).data().replicaState().replicaId() == -1);
} else {
assertTrue(((FetchRequest) request).data().replicaId() == -1);
assertTrue(((FetchRequest) request).data().replicaState().replicaId() == 1);
}
}
private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) {
int correlationId = channel.newCorrelationId();
long createdTimeMs = time.milliseconds();
ApiMessage apiRequest = buildTestRequest(apiKey);
RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs);
channel.send(request);
return request;
}
private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException {
assertTrue(request.completion.isDone());
RaftResponse.Inbound response = request.completion.get();
assertEquals(request.destinationId(), response.sourceId());
assertEquals(request.correlationId, response.correlationId);
assertEquals(request.data.apiKey(), response.data.apiKey());
assertEquals(expectedError, extractError(response.data));
}
private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException {
RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId);
channel.pollOnce();
assertResponseCompleted(request, error);
}
private ApiMessage buildTestRequest(ApiKeys key) {
int leaderEpoch = 5;
int leaderId = 1;
switch (key) {
case BEGIN_QUORUM_EPOCH:
return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId);
case END_QUORUM_EPOCH:
return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch,
Collections.singletonList(2));
case VOTE:
int lastEpoch = 4;
return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329);
case FETCH:
FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> {
fetchPartition
.setCurrentLeaderEpoch(5)
.setFetchOffset(333)
.setLastFetchedEpoch(5);
});
request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1));
return request;
default:
throw new AssertionError("Unexpected api " + key);
}
}
private ApiMessage buildTestErrorResponse(ApiKeys key, Errors error) {
switch (key) {
case BEGIN_QUORUM_EPOCH:
return new BeginQuorumEpochResponseData().setErrorCode(error.code());
case END_QUORUM_EPOCH:
return new EndQuorumEpochResponseData().setErrorCode(error.code());
case VOTE:
return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
case FETCH:
return new FetchResponseData().setErrorCode(error.code());
default:
throw new AssertionError("Unexpected api " + key);
}
}
private Errors extractError(ApiMessage response) {
short code;
if (response instanceof BeginQuorumEpochResponseData)
code = ((BeginQuorumEpochResponseData) response).errorCode();
else if (response instanceof EndQuorumEpochResponseData)
code = ((EndQuorumEpochResponseData) response).errorCode();
else if (response instanceof FetchResponseData)
code = ((FetchResponseData) response).errorCode();
else if (response instanceof VoteResponseData)
code = ((VoteResponseData) response).errorCode();
else
throw new IllegalArgumentException("Unexpected type for responseData: " + response);
return Errors.forCode(code);
}
private AbstractResponse buildResponse(ApiMessage responseData) {
if (responseData instanceof VoteResponseData)
return new VoteResponse((VoteResponseData) responseData);
if (responseData instanceof BeginQuorumEpochResponseData)
return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData);
if (responseData instanceof EndQuorumEpochResponseData)
return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData);
if (responseData instanceof FetchResponseData)
return new FetchResponse((FetchResponseData) responseData);
throw new IllegalArgumentException("Unexpected type for responseData: " + responseData);
}
}
Loading…
Cancel
Save