diff --git a/checkstyle/import-control.xml b/checkstyle/import-control.xml
index 888e8a41ae8..42488c3225f 100644
--- a/checkstyle/import-control.xml
+++ b/checkstyle/import-control.xml
@@ -406,6 +406,7 @@
+
diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
deleted file mode 100644
index 7c00961d1dc..00000000000
--- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.raft
-
-import kafka.utils.Logging
-import org.apache.kafka.clients.{ClientResponse, KafkaClient}
-import org.apache.kafka.common.Node
-import org.apache.kafka.common.message._
-import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
-import org.apache.kafka.common.requests._
-import org.apache.kafka.common.utils.Time
-import org.apache.kafka.raft.RaftConfig.InetAddressSpec
-import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil}
-import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler}
-
-import java.util
-import java.util.concurrent.ConcurrentLinkedQueue
-import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.mutable
-
-object KafkaNetworkChannel {
-
- private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = {
- requestData match {
- case voteRequest: VoteRequestData =>
- new VoteRequest.Builder(voteRequest)
- case beginEpochRequest: BeginQuorumEpochRequestData =>
- new BeginQuorumEpochRequest.Builder(beginEpochRequest)
- case endEpochRequest: EndQuorumEpochRequestData =>
- new EndQuorumEpochRequest.Builder(endEpochRequest)
- case fetchRequest: FetchRequestData =>
- new FetchRequest.SimpleBuilder(fetchRequest)
- case fetchSnapshotRequest: FetchSnapshotRequestData =>
- new FetchSnapshotRequest.Builder(fetchSnapshotRequest)
- case _ =>
- throw new IllegalArgumentException(s"Unexpected type for requestData: $requestData")
- }
- }
-
-}
-
-private[raft] class RaftSendThread(
- name: String,
- networkClient: KafkaClient,
- requestTimeoutMs: Int,
- time: Time,
- isInterruptible: Boolean = true
-) extends InterBrokerSendThread(
- name,
- networkClient,
- requestTimeoutMs,
- time,
- isInterruptible
-) {
- private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]()
-
- def generateRequests(): util.Collection[RequestAndCompletionHandler] = {
- val list = new util.ArrayList[RequestAndCompletionHandler]()
- while (true) {
- val request = queue.poll()
- if (request == null) {
- return list
- } else {
- list.add(request)
- }
- }
- list
- }
-
- def sendRequest(request: RequestAndCompletionHandler): Unit = {
- queue.add(request)
- wakeup()
- }
-
-}
-
-
-class KafkaNetworkChannel(
- time: Time,
- client: KafkaClient,
- requestTimeoutMs: Int,
- threadNamePrefix: String
-) extends NetworkChannel with Logging {
- import KafkaNetworkChannel._
-
- type ResponseHandler = AbstractResponse => Unit
-
- private val correlationIdCounter = new AtomicInteger(0)
- private val endpoints = mutable.HashMap.empty[Int, Node]
-
- private val requestThread = new RaftSendThread(
- name = threadNamePrefix + "-outbound-request-thread",
- networkClient = client,
- requestTimeoutMs = requestTimeoutMs,
- time = time,
- isInterruptible = false
- )
-
- override def send(request: RaftRequest.Outbound): Unit = {
- def completeFuture(message: ApiMessage): Unit = {
- val response = new RaftResponse.Inbound(
- request.correlationId,
- message,
- request.destinationId
- )
- request.completion.complete(response)
- }
-
- def onComplete(clientResponse: ClientResponse): Unit = {
- val response = if (clientResponse.versionMismatch != null) {
- error(s"Request $request failed due to unsupported version error",
- clientResponse.versionMismatch)
- errorResponse(request.data, Errors.UNSUPPORTED_VERSION)
- } else if (clientResponse.authenticationException != null) {
- // For now we treat authentication errors as retriable. We use the
- // `NETWORK_EXCEPTION` error code for lack of a good alternative.
- // Note that `NodeToControllerChannelManager` will still log the
- // authentication errors so that users have a chance to fix the problem.
- error(s"Request $request failed due to authentication error",
- clientResponse.authenticationException)
- errorResponse(request.data, Errors.NETWORK_EXCEPTION)
- } else if (clientResponse.wasDisconnected()) {
- errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)
- } else {
- clientResponse.responseBody.data
- }
- completeFuture(response)
- }
-
- endpoints.get(request.destinationId) match {
- case Some(node) =>
- requestThread.sendRequest(new RequestAndCompletionHandler(
- request.createdTimeMs,
- node,
- buildRequest(request.data),
- onComplete
- ))
-
- case None =>
- completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE))
- }
- }
-
- // Visible for testing
- private[raft] def pollOnce(): Unit = {
- requestThread.doWork()
- }
-
- override def newCorrelationId(): Int = {
- correlationIdCounter.getAndIncrement()
- }
-
- private def errorResponse(
- request: ApiMessage,
- error: Errors
- ): ApiMessage = {
- val apiKey = ApiKeys.forId(request.apiKey)
- RaftUtil.errorResponse(apiKey, error)
- }
-
- override def updateEndpoint(id: Int, spec: InetAddressSpec): Unit = {
- val node = new Node(id, spec.address.getHostString, spec.address.getPort)
- endpoints.put(id, node)
- }
-
- def start(): Unit = {
- requestThread.start()
- }
-
- def initiateShutdown(): Unit = {
- requestThread.initiateShutdown()
- }
-
- override def close(): Unit = {
- requestThread.shutdown()
- }
-}
diff --git a/core/src/main/scala/kafka/raft/RaftManager.scala b/core/src/main/scala/kafka/raft/RaftManager.scala
index 020477d5a42..f9311d20d95 100644
--- a/core/src/main/scala/kafka/raft/RaftManager.scala
+++ b/core/src/main/scala/kafka/raft/RaftManager.scala
@@ -42,7 +42,7 @@ import org.apache.kafka.common.security.JaasContext
import org.apache.kafka.common.security.auth.SecurityProtocol
import org.apache.kafka.common.utils.{LogContext, Time}
import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec, NON_ROUTABLE_ADDRESS, UnknownAddressSpec}
-import org.apache.kafka.raft.{FileBasedStateStore, KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog}
+import org.apache.kafka.raft.{FileBasedStateStore, KafkaNetworkChannel, KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog}
import org.apache.kafka.server.common.serialization.RecordSerde
import org.apache.kafka.server.util.{KafkaScheduler, ShutdownableThread}
import org.apache.kafka.server.fault.FaultHandler
diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
deleted file mode 100644
index af230f66553..00000000000
--- a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
+++ /dev/null
@@ -1,316 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.raft
-
-import java.net.InetSocketAddress
-import java.util
-import java.util.Collections
-import org.apache.kafka.clients.MockClient.MockMetadataUpdater
-import org.apache.kafka.clients.{MockClient, NodeApiVersions}
-import org.apache.kafka.common.message.FetchRequestData.ReplicaState
-import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData}
-import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
-import org.apache.kafka.common.requests.{AbstractResponse, ApiVersionsResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchRequest, FetchResponse, VoteRequest, VoteResponse}
-import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource
-import org.apache.kafka.common.utils.{MockTime, Time}
-import org.apache.kafka.common.{Node, TopicPartition, Uuid}
-import org.apache.kafka.raft.RaftConfig.InetAddressSpec
-import org.apache.kafka.raft.{RaftRequest, RaftUtil}
-import org.junit.jupiter.api.Assertions._
-import org.junit.jupiter.api.{BeforeEach, Test}
-import org.junit.jupiter.params.ParameterizedTest
-
-import scala.jdk.CollectionConverters._
-
-class KafkaNetworkChannelTest {
- import KafkaNetworkChannelTest._
-
- private val clusterId = "clusterId"
- private val requestTimeoutMs = 30000
- private val time = new MockTime()
- private val client = new MockClient(time, new StubMetadataUpdater)
- private val topicPartition = new TopicPartition("topic", 0)
- private val topicId = Uuid.randomUuid()
- private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, threadNamePrefix = "test-raft")
-
- @BeforeEach
- def setupSupportedApis(): Unit = {
- val supportedApis = RaftApis.map(ApiVersionsResponse.toApiVersion)
- client.setNodeApiVersions(NodeApiVersions.create(supportedApis.asJava))
- }
-
- @Test
- def testSendToUnknownDestination(): Unit = {
- val destinationId = 2
- assertBrokerNotAvailable(destinationId)
- }
-
- @Test
- def testSendToBlackedOutDestination(): Unit = {
- val destinationId = 2
- val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
- channel.updateEndpoint(destinationId, new InetAddressSpec(
- new InetSocketAddress(destinationNode.host, destinationNode.port)))
- client.backoff(destinationNode, 500)
- assertBrokerNotAvailable(destinationId)
- }
-
- @Test
- def testWakeupClientOnSend(): Unit = {
- val destinationId = 2
- val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
- channel.updateEndpoint(destinationId, new InetAddressSpec(
- new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
- client.enableBlockingUntilWakeup(1)
-
- val ioThread = new Thread() {
- override def run(): Unit = {
- // Block in poll until we get the expected wakeup
- channel.pollOnce()
-
- // Poll a second time to send request and receive response
- channel.pollOnce()
- }
- }
-
- val response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST))
- client.prepareResponseFrom(response, destinationNode, false)
-
- ioThread.start()
- val request = sendTestRequest(ApiKeys.FETCH, destinationId)
-
- ioThread.join()
- assertResponseCompleted(request, Errors.INVALID_REQUEST)
- }
-
- @Test
- def testSendAndDisconnect(): Unit = {
- val destinationId = 2
- val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
- channel.updateEndpoint(destinationId, new InetAddressSpec(
- new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
- for (apiKey <- RaftApis) {
- val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST))
- client.prepareResponseFrom(response, destinationNode, true)
- sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE)
- }
- }
-
- @Test
- def testSendAndFailAuthentication(): Unit = {
- val destinationId = 2
- val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
- channel.updateEndpoint(destinationId, new InetAddressSpec(
- new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
- for (apiKey <- RaftApis) {
- client.createPendingAuthenticationError(destinationNode, 100)
- sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION)
-
- // reset to clear backoff time
- client.reset()
- }
- }
-
- private def assertBrokerNotAvailable(destinationId: Int): Unit = {
- for (apiKey <- RaftApis) {
- sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE)
- }
- }
-
- @Test
- def testSendAndReceiveOutboundRequest(): Unit = {
- val destinationId = 2
- val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
- channel.updateEndpoint(destinationId, new InetAddressSpec(
- new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
- for (apiKey <- RaftApis) {
- val expectedError = Errors.INVALID_REQUEST
- val response = buildResponse(buildTestErrorResponse(apiKey, expectedError))
- client.prepareResponseFrom(response, destinationNode)
- sendAndAssertErrorResponse(apiKey, destinationId, expectedError)
- }
- }
-
- @Test
- def testUnsupportedVersionError(): Unit = {
- val destinationId = 2
- val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
- channel.updateEndpoint(destinationId, new InetAddressSpec(
- new InetSocketAddress(destinationNode.host, destinationNode.port)))
-
- for (apiKey <- RaftApis) {
- client.prepareUnsupportedVersionResponse(request => request.apiKey == apiKey)
- sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION)
- }
- }
-
- @ParameterizedTest
- @ApiKeyVersionsSource(apiKey = ApiKeys.FETCH)
- def testFetchRequestDowngrade(version: Short): Unit = {
- val destinationId = 2
- val destinationNode = new Node(destinationId, "127.0.0.1", 9092)
- channel.updateEndpoint(destinationId, new InetAddressSpec(
- new InetSocketAddress(destinationNode.host, destinationNode.port)))
- sendTestRequest(ApiKeys.FETCH, destinationId)
- channel.pollOnce()
-
- assertEquals(1, client.requests.size)
- val request = client.requests.peek.requestBuilder.build(version)
-
- if (version < 15) {
- assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == 1)
- assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == -1)
- } else {
- assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == -1)
- assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == 1)
- }
- }
-
- private def sendTestRequest(
- apiKey: ApiKeys,
- destinationId: Int,
- ): RaftRequest.Outbound = {
- val correlationId = channel.newCorrelationId()
- val createdTimeMs = time.milliseconds()
- val apiRequest = buildTestRequest(apiKey)
- val request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs)
- channel.send(request)
- request
- }
-
- private def assertResponseCompleted(
- request: RaftRequest.Outbound,
- expectedError: Errors
- ): Unit = {
- assertTrue(request.completion.isDone)
-
- val response = request.completion.get()
- assertEquals(request.destinationId, response.sourceId)
- assertEquals(request.correlationId, response.correlationId)
- assertEquals(request.data.apiKey, response.data.apiKey)
- assertEquals(expectedError, extractError(response.data))
- }
-
- private def sendAndAssertErrorResponse(
- apiKey: ApiKeys,
- destinationId: Int,
- error: Errors
- ): Unit = {
- val request = sendTestRequest(apiKey, destinationId)
- channel.pollOnce()
- assertResponseCompleted(request, error)
- }
-
- private def buildTestRequest(key: ApiKeys): ApiMessage = {
- val leaderEpoch = 5
- val leaderId = 1
- key match {
- case ApiKeys.BEGIN_QUORUM_EPOCH =>
- BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId)
-
- case ApiKeys.END_QUORUM_EPOCH =>
- EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId,
- leaderEpoch, Collections.singletonList(2))
-
- case ApiKeys.VOTE =>
- val lastEpoch = 4
- VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329)
-
- case ApiKeys.FETCH =>
- val request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition => {
- fetchPartition
- .setCurrentLeaderEpoch(5)
- .setFetchOffset(333)
- .setLastFetchedEpoch(5)
- })
- request.setReplicaState(new ReplicaState().setReplicaId(1))
-
- case _ =>
- throw new AssertionError(s"Unexpected api $key")
- }
- }
-
- private def buildTestErrorResponse(key: ApiKeys, error: Errors): ApiMessage = {
- key match {
- case ApiKeys.BEGIN_QUORUM_EPOCH =>
- new BeginQuorumEpochResponseData()
- .setErrorCode(error.code)
-
- case ApiKeys.END_QUORUM_EPOCH =>
- new EndQuorumEpochResponseData()
- .setErrorCode(error.code)
-
- case ApiKeys.VOTE =>
- VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
-
- case ApiKeys.FETCH =>
- new FetchResponseData()
- .setErrorCode(error.code)
-
- case _ =>
- throw new AssertionError(s"Unexpected api $key")
- }
- }
-
- private def extractError(response: ApiMessage): Errors = {
- val code = (response: @unchecked) match {
- case res: BeginQuorumEpochResponseData => res.errorCode
- case res: EndQuorumEpochResponseData => res.errorCode
- case res: FetchResponseData => res.errorCode
- case res: VoteResponseData => res.errorCode
- }
- Errors.forCode(code)
- }
-
-
- def buildResponse(responseData: ApiMessage): AbstractResponse = {
- responseData match {
- case voteResponse: VoteResponseData =>
- new VoteResponse(voteResponse)
- case beginEpochResponse: BeginQuorumEpochResponseData =>
- new BeginQuorumEpochResponse(beginEpochResponse)
- case endEpochResponse: EndQuorumEpochResponseData =>
- new EndQuorumEpochResponse(endEpochResponse)
- case fetchResponse: FetchResponseData =>
- new FetchResponse(fetchResponse)
- case _ =>
- throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData")
- }
- }
-
-}
-
-object KafkaNetworkChannelTest {
- val RaftApis = Seq(
- ApiKeys.VOTE,
- ApiKeys.BEGIN_QUORUM_EPOCH,
- ApiKeys.END_QUORUM_EPOCH,
- ApiKeys.FETCH,
- )
-
- private class StubMetadataUpdater extends MockMetadataUpdater {
- override def fetchNodes(): util.List[Node] = Collections.emptyList()
-
- override def isUpdateNeeded: Boolean = false
-
- override def update(time: Time, update: MockClient.MetadataUpdate): Unit = {}
- }
-}
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
new file mode 100644
index 00000000000..2c0dd25d439
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import org.apache.kafka.clients.ClientResponse;
+import org.apache.kafka.clients.KafkaClient;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.message.BeginQuorumEpochRequestData;
+import org.apache.kafka.common.message.EndQuorumEpochRequestData;
+import org.apache.kafka.common.message.FetchRequestData;
+import org.apache.kafka.common.message.FetchSnapshotRequestData;
+import org.apache.kafka.common.message.VoteRequestData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
+import org.apache.kafka.common.requests.EndQuorumEpochRequest;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchSnapshotRequest;
+import org.apache.kafka.common.requests.VoteRequest;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.server.util.InterBrokerSendThread;
+import org.apache.kafka.server.util.RequestAndCompletionHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class KafkaNetworkChannel implements NetworkChannel {
+
+ static class SendThread extends InterBrokerSendThread {
+
+ private Queue queue = new ConcurrentLinkedQueue<>();
+
+ public SendThread(String name, KafkaClient networkClient, int requestTimeoutMs, Time time, boolean isInterruptible) {
+ super(name, networkClient, requestTimeoutMs, time, isInterruptible);
+ }
+
+ @Override
+ public Collection generateRequests() {
+ List list = new ArrayList<>();
+ while (true) {
+ RequestAndCompletionHandler request = queue.poll();
+ if (request == null) {
+ return list;
+ } else {
+ list.add(request);
+ }
+ }
+ }
+
+ public void sendRequest(RequestAndCompletionHandler request) {
+ queue.add(request);
+ wakeup();
+ }
+ }
+
+ private static final Logger log = LoggerFactory.getLogger(KafkaNetworkChannel.class);
+
+ private final SendThread requestThread;
+
+ private final AtomicInteger correlationIdCounter = new AtomicInteger(0);
+ private final Map endpoints = new HashMap<>();
+
+ public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) {
+ this.requestThread = new SendThread(
+ threadNamePrefix + "-outbound-request-thread",
+ client,
+ requestTimeoutMs,
+ time,
+ false
+ );
+ }
+
+ @Override
+ public int newCorrelationId() {
+ return correlationIdCounter.getAndIncrement();
+ }
+
+ @Override
+ public void send(RaftRequest.Outbound request) {
+ Node node = endpoints.get(request.destinationId());
+ if (node != null) {
+ requestThread.sendRequest(new RequestAndCompletionHandler(
+ request.createdTimeMs,
+ node,
+ buildRequest(request.data),
+ response -> sendOnComplete(request, response)
+ ));
+ } else
+ sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE));
+ }
+
+ private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) {
+ RaftResponse.Inbound response = new RaftResponse.Inbound(
+ request.correlationId,
+ message,
+ request.destinationId()
+ );
+ request.completion.complete(response);
+ }
+
+ private void sendOnComplete(RaftRequest.Outbound request, ClientResponse clientResponse) {
+ ApiMessage response;
+ if (clientResponse.versionMismatch() != null) {
+ log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch());
+ response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION);
+ } else if (clientResponse.authenticationException() != null) {
+ // For now we treat authentication errors as retriable. We use the
+ // `NETWORK_EXCEPTION` error code for lack of a good alternative.
+ // Note that `NodeToControllerChannelManager` will still log the
+ // authentication errors so that users have a chance to fix the problem.
+ log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException());
+ response = errorResponse(request.data, Errors.NETWORK_EXCEPTION);
+ } else if (clientResponse.wasDisconnected()) {
+ response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE);
+ } else {
+ response = clientResponse.responseBody().data();
+ }
+ sendCompleteFuture(request, response);
+ }
+
+ private ApiMessage errorResponse(ApiMessage request, Errors error) {
+ ApiKeys apiKey = ApiKeys.forId(request.apiKey());
+ return RaftUtil.errorResponse(apiKey, error);
+ }
+
+ @Override
+ public void updateEndpoint(int id, RaftConfig.InetAddressSpec spec) {
+ Node node = new Node(id, spec.address.getHostString(), spec.address.getPort());
+ endpoints.put(id, node);
+ }
+
+ public void start() {
+ requestThread.start();
+ }
+
+ @Override
+ public void close() throws InterruptedException {
+ requestThread.shutdown();
+ }
+
+ // Visible for testing
+ public void pollOnce() {
+ requestThread.doWork();
+ }
+
+ static AbstractRequest.Builder extends AbstractRequest> buildRequest(ApiMessage requestData) {
+ if (requestData instanceof VoteRequestData)
+ return new VoteRequest.Builder((VoteRequestData) requestData);
+ if (requestData instanceof BeginQuorumEpochRequestData)
+ return new BeginQuorumEpochRequest.Builder((BeginQuorumEpochRequestData) requestData);
+ if (requestData instanceof EndQuorumEpochRequestData)
+ return new EndQuorumEpochRequest.Builder((EndQuorumEpochRequestData) requestData);
+ if (requestData instanceof FetchRequestData)
+ return new FetchRequest.SimpleBuilder((FetchRequestData) requestData);
+ if (requestData instanceof FetchSnapshotRequestData)
+ return new FetchSnapshotRequest.Builder((FetchSnapshotRequestData) requestData);
+ throw new IllegalArgumentException("Unexpected type for requestData: " + requestData);
+ }
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
index e3482e56751..e527adf6f9b 100644
--- a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
+++ b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
@@ -16,13 +16,11 @@
*/
package org.apache.kafka.raft;
-import java.io.Closeable;
-
/**
* A simple network interface with few assumptions. We do not assume ordering
* of requests or even that every outbound request will receive a response.
*/
-public interface NetworkChannel extends Closeable {
+public interface NetworkChannel extends AutoCloseable {
/**
* Generate a new and unique correlationId for a new request to be sent.
@@ -41,6 +39,6 @@ public interface NetworkChannel extends Closeable {
*/
void updateEndpoint(int id, RaftConfig.InetAddressSpec address);
- default void close() {}
+ default void close() throws InterruptedException {}
}
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java
new file mode 100644
index 00000000000..3a1d097fc7a
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import org.apache.kafka.clients.MockClient;
+import org.apache.kafka.clients.NodeApiVersions;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.message.ApiVersionsResponseData;
+import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
+import org.apache.kafka.common.message.EndQuorumEpochResponseData;
+import org.apache.kafka.common.message.FetchRequestData;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.message.VoteResponseData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.AbstractResponse;
+import org.apache.kafka.common.requests.ApiVersionsResponse;
+import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
+import org.apache.kafka.common.requests.BeginQuorumEpochResponse;
+import org.apache.kafka.common.requests.EndQuorumEpochRequest;
+import org.apache.kafka.common.requests.EndQuorumEpochResponse;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.requests.VoteRequest;
+import org.apache.kafka.common.requests.VoteResponse;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+
+import java.net.InetSocketAddress;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class KafkaNetworkChannelTest {
+
+ private static class StubMetadataUpdater implements MockClient.MockMetadataUpdater {
+
+ @Override
+ public List fetchNodes() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public boolean isUpdateNeeded() {
+ return false;
+ }
+
+ @Override
+ public void update(Time time, MockClient.MetadataUpdate update) { }
+ }
+
+ private static final List RAFT_APIS = asList(
+ ApiKeys.VOTE,
+ ApiKeys.BEGIN_QUORUM_EPOCH,
+ ApiKeys.END_QUORUM_EPOCH,
+ ApiKeys.FETCH
+ );
+
+ private final String clusterId = "clusterId";
+ private final int requestTimeoutMs = 30000;
+ private final Time time = new MockTime();
+ private final MockClient client = new MockClient(time, new StubMetadataUpdater());
+ private final TopicPartition topicPartition = new TopicPartition("topic", 0);
+ private final Uuid topicId = Uuid.randomUuid();
+ private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft");
+
+ @BeforeEach
+ public void setupSupportedApis() {
+ List supportedApis = RAFT_APIS.stream().map(
+ ApiVersionsResponse::toApiVersion).collect(Collectors.toList());
+ client.setNodeApiVersions(NodeApiVersions.create(supportedApis));
+ }
+
+ @Test
+ public void testSendToUnknownDestination() throws ExecutionException, InterruptedException {
+ int destinationId = 2;
+ assertBrokerNotAvailable(destinationId);
+ }
+
+ @Test
+ public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException {
+ int destinationId = 2;
+ Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+ channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+ new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+ client.backoff(destinationNode, 500);
+ assertBrokerNotAvailable(destinationId);
+ }
+
+ @Test
+ public void testWakeupClientOnSend() throws InterruptedException, ExecutionException {
+ int destinationId = 2;
+ Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+ channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+ new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+ client.enableBlockingUntilWakeup(1);
+
+ Thread ioThread = new Thread(() -> {
+ // Block in poll until we get the expected wakeup
+ channel.pollOnce();
+
+ // Poll a second time to send request and receive response
+ channel.pollOnce();
+ });
+
+ AbstractResponse response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST));
+ client.prepareResponseFrom(response, destinationNode, false);
+
+ ioThread.start();
+ RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId);
+
+ ioThread.join();
+ assertResponseCompleted(request, Errors.INVALID_REQUEST);
+ }
+
+ @Test
+ public void testSendAndDisconnect() throws ExecutionException, InterruptedException {
+ int destinationId = 2;
+ Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+ channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+ new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+ for (ApiKeys apiKey : RAFT_APIS) {
+ AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST));
+ client.prepareResponseFrom(response, destinationNode, true);
+ sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
+ }
+ }
+
+ @Test
+ public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException {
+ int destinationId = 2;
+ Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+ channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+ new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+ for (ApiKeys apiKey : RAFT_APIS) {
+ client.createPendingAuthenticationError(destinationNode, 100);
+ sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION);
+
+ // reset to clear backoff time
+ client.reset();
+ }
+ }
+
+ private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException {
+ for (ApiKeys apiKey : RAFT_APIS) {
+ sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
+ }
+ }
+
+ @Test
+ public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException {
+ int destinationId = 2;
+ Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+ channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+ new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+ for (ApiKeys apiKey : RAFT_APIS) {
+ Errors expectedError = Errors.INVALID_REQUEST;
+ AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError));
+ client.prepareResponseFrom(response, destinationNode);
+ System.out.println("api key " + apiKey + ", response " + response);
+ sendAndAssertErrorResponse(apiKey, destinationId, expectedError);
+ }
+ }
+
+ @Test
+ public void testUnsupportedVersionError() throws ExecutionException, InterruptedException {
+ int destinationId = 2;
+ Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+ channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+ new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+ for (ApiKeys apiKey : RAFT_APIS) {
+ client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey);
+ sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION);
+ }
+ }
+
+ @ParameterizedTest
+ @ApiKeyVersionsSource(apiKey = ApiKeys.FETCH)
+ public void testFetchRequestDowngrade(short version) {
+ int destinationId = 2;
+ Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+ channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+ new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+ sendTestRequest(ApiKeys.FETCH, destinationId);
+ channel.pollOnce();
+
+ assertEquals(1, client.requests().size());
+ AbstractRequest request = client.requests().peek().requestBuilder().build(version);
+
+ if (version < 15) {
+ assertTrue(((FetchRequest) request).data().replicaId() == 1);
+ assertTrue(((FetchRequest) request).data().replicaState().replicaId() == -1);
+ } else {
+ assertTrue(((FetchRequest) request).data().replicaId() == -1);
+ assertTrue(((FetchRequest) request).data().replicaState().replicaId() == 1);
+ }
+ }
+
+ private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) {
+ int correlationId = channel.newCorrelationId();
+ long createdTimeMs = time.milliseconds();
+ ApiMessage apiRequest = buildTestRequest(apiKey);
+ RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs);
+ channel.send(request);
+ return request;
+ }
+
+ private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException {
+ assertTrue(request.completion.isDone());
+
+ RaftResponse.Inbound response = request.completion.get();
+ assertEquals(request.destinationId(), response.sourceId());
+ assertEquals(request.correlationId, response.correlationId);
+ assertEquals(request.data.apiKey(), response.data.apiKey());
+ assertEquals(expectedError, extractError(response.data));
+ }
+
+ private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException {
+ RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId);
+ channel.pollOnce();
+ assertResponseCompleted(request, error);
+ }
+
+ private ApiMessage buildTestRequest(ApiKeys key) {
+ int leaderEpoch = 5;
+ int leaderId = 1;
+ switch (key) {
+ case BEGIN_QUORUM_EPOCH:
+ return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId);
+ case END_QUORUM_EPOCH:
+ return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch,
+ Collections.singletonList(2));
+ case VOTE:
+ int lastEpoch = 4;
+ return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329);
+ case FETCH:
+ FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> {
+ fetchPartition
+ .setCurrentLeaderEpoch(5)
+ .setFetchOffset(333)
+ .setLastFetchedEpoch(5);
+ });
+ request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1));
+ return request;
+ default:
+ throw new AssertionError("Unexpected api " + key);
+ }
+ }
+
+ private ApiMessage buildTestErrorResponse(ApiKeys key, Errors error) {
+ switch (key) {
+ case BEGIN_QUORUM_EPOCH:
+ return new BeginQuorumEpochResponseData().setErrorCode(error.code());
+ case END_QUORUM_EPOCH:
+ return new EndQuorumEpochResponseData().setErrorCode(error.code());
+ case VOTE:
+ return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
+ case FETCH:
+ return new FetchResponseData().setErrorCode(error.code());
+ default:
+ throw new AssertionError("Unexpected api " + key);
+ }
+ }
+
+ private Errors extractError(ApiMessage response) {
+ short code;
+ if (response instanceof BeginQuorumEpochResponseData)
+ code = ((BeginQuorumEpochResponseData) response).errorCode();
+ else if (response instanceof EndQuorumEpochResponseData)
+ code = ((EndQuorumEpochResponseData) response).errorCode();
+ else if (response instanceof FetchResponseData)
+ code = ((FetchResponseData) response).errorCode();
+ else if (response instanceof VoteResponseData)
+ code = ((VoteResponseData) response).errorCode();
+ else
+ throw new IllegalArgumentException("Unexpected type for responseData: " + response);
+ return Errors.forCode(code);
+ }
+
+ private AbstractResponse buildResponse(ApiMessage responseData) {
+ if (responseData instanceof VoteResponseData)
+ return new VoteResponse((VoteResponseData) responseData);
+ if (responseData instanceof BeginQuorumEpochResponseData)
+ return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData);
+ if (responseData instanceof EndQuorumEpochResponseData)
+ return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData);
+ if (responseData instanceof FetchResponseData)
+ return new FetchResponse((FetchResponseData) responseData);
+ throw new IllegalArgumentException("Unexpected type for responseData: " + responseData);
+ }
+}