From 69e591db3a7329a8bb984f068596d8658a8618b3 Mon Sep 17 00:00:00 2001
From: Ismael Juma
Date: Mon, 16 Oct 2023 20:10:31 -0700
Subject: [PATCH] MINOR: Rewrite/Move KafkaNetworkChannel to the `raft` module (#14559)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is now possible since `InterBrokerSendThread` was moved from `core` to
`server-common`. Also rewrite/move `KafkaNetworkChannelTest`.

The Scala version of `KafkaNetworkChannelTest` passed with the changes here
(before I deleted it).

Reviewers: Justine Olshan, José Armando García Sancio
---
 checkstyle/import-control.xml                 |   1 +
 .../kafka/raft/KafkaNetworkChannel.scala      | 191 -----------
 .../main/scala/kafka/raft/RaftManager.scala   |   2 +-
 .../kafka/raft/KafkaNetworkChannelTest.scala  | 316 -----------------
 .../kafka/raft/KafkaNetworkChannel.java       | 183 ++++++++++
 .../org/apache/kafka/raft/NetworkChannel.java |   6 +-
 .../kafka/raft/KafkaNetworkChannelTest.java   | 323 ++++++++++++++++++
 7 files changed, 510 insertions(+), 512 deletions(-)
 delete mode 100644 core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
 delete mode 100644 core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
 create mode 100644 raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
 create mode 100644 raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java

diff --git a/checkstyle/import-control.xml b/checkstyle/import-control.xml
index 888e8a41ae8..42488c3225f 100644
--- a/checkstyle/import-control.xml
+++ b/checkstyle/import-control.xml
@@ -406,6 +406,7 @@
 
 
 
+
 
 
 
diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
deleted file mode 100644
index 7c00961d1dc..00000000000
--- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -package kafka.raft - -import kafka.utils.Logging -import org.apache.kafka.clients.{ClientResponse, KafkaClient} -import org.apache.kafka.common.Node -import org.apache.kafka.common.message._ -import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} -import org.apache.kafka.common.requests._ -import org.apache.kafka.common.utils.Time -import org.apache.kafka.raft.RaftConfig.InetAddressSpec -import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil} -import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler} - -import java.util -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable - -object KafkaNetworkChannel { - - private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = { - requestData match { - case voteRequest: VoteRequestData => - new VoteRequest.Builder(voteRequest) - case beginEpochRequest: BeginQuorumEpochRequestData => - new BeginQuorumEpochRequest.Builder(beginEpochRequest) - case endEpochRequest: EndQuorumEpochRequestData => - new EndQuorumEpochRequest.Builder(endEpochRequest) - case fetchRequest: FetchRequestData => - new FetchRequest.SimpleBuilder(fetchRequest) - case fetchSnapshotRequest: FetchSnapshotRequestData => - new FetchSnapshotRequest.Builder(fetchSnapshotRequest) - case _ => - throw new IllegalArgumentException(s"Unexpected type for requestData: $requestData") - } - } - -} - -private[raft] class RaftSendThread( - name: String, - networkClient: KafkaClient, - requestTimeoutMs: Int, - time: Time, - isInterruptible: Boolean = true -) extends InterBrokerSendThread( - name, - networkClient, - requestTimeoutMs, - time, - isInterruptible -) { - private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() - - def generateRequests(): util.Collection[RequestAndCompletionHandler] = { - val list = new util.ArrayList[RequestAndCompletionHandler]() - while (true) { - val request = queue.poll() - if (request == null) { - return list - } else { - list.add(request) - } - } - list - } - - def sendRequest(request: RequestAndCompletionHandler): Unit = { - queue.add(request) - wakeup() - } - -} - - -class KafkaNetworkChannel( - time: Time, - client: KafkaClient, - requestTimeoutMs: Int, - threadNamePrefix: String -) extends NetworkChannel with Logging { - import KafkaNetworkChannel._ - - type ResponseHandler = AbstractResponse => Unit - - private val correlationIdCounter = new AtomicInteger(0) - private val endpoints = mutable.HashMap.empty[Int, Node] - - private val requestThread = new RaftSendThread( - name = threadNamePrefix + "-outbound-request-thread", - networkClient = client, - requestTimeoutMs = requestTimeoutMs, - time = time, - isInterruptible = false - ) - - override def send(request: RaftRequest.Outbound): Unit = { - def completeFuture(message: ApiMessage): Unit = { - val response = new RaftResponse.Inbound( - request.correlationId, - message, - request.destinationId - ) - request.completion.complete(response) - } - - def onComplete(clientResponse: ClientResponse): Unit = { - val response = if (clientResponse.versionMismatch != null) { - error(s"Request $request failed due to unsupported version error", - clientResponse.versionMismatch) - errorResponse(request.data, Errors.UNSUPPORTED_VERSION) - } else if (clientResponse.authenticationException != null) { - // For now we treat authentication errors as retriable. 
We use the - // `NETWORK_EXCEPTION` error code for lack of a good alternative. - // Note that `NodeToControllerChannelManager` will still log the - // authentication errors so that users have a chance to fix the problem. - error(s"Request $request failed due to authentication error", - clientResponse.authenticationException) - errorResponse(request.data, Errors.NETWORK_EXCEPTION) - } else if (clientResponse.wasDisconnected()) { - errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE) - } else { - clientResponse.responseBody.data - } - completeFuture(response) - } - - endpoints.get(request.destinationId) match { - case Some(node) => - requestThread.sendRequest(new RequestAndCompletionHandler( - request.createdTimeMs, - node, - buildRequest(request.data), - onComplete - )) - - case None => - completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)) - } - } - - // Visible for testing - private[raft] def pollOnce(): Unit = { - requestThread.doWork() - } - - override def newCorrelationId(): Int = { - correlationIdCounter.getAndIncrement() - } - - private def errorResponse( - request: ApiMessage, - error: Errors - ): ApiMessage = { - val apiKey = ApiKeys.forId(request.apiKey) - RaftUtil.errorResponse(apiKey, error) - } - - override def updateEndpoint(id: Int, spec: InetAddressSpec): Unit = { - val node = new Node(id, spec.address.getHostString, spec.address.getPort) - endpoints.put(id, node) - } - - def start(): Unit = { - requestThread.start() - } - - def initiateShutdown(): Unit = { - requestThread.initiateShutdown() - } - - override def close(): Unit = { - requestThread.shutdown() - } -} diff --git a/core/src/main/scala/kafka/raft/RaftManager.scala b/core/src/main/scala/kafka/raft/RaftManager.scala index 020477d5a42..f9311d20d95 100644 --- a/core/src/main/scala/kafka/raft/RaftManager.scala +++ b/core/src/main/scala/kafka/raft/RaftManager.scala @@ -42,7 +42,7 @@ import org.apache.kafka.common.security.JaasContext import org.apache.kafka.common.security.auth.SecurityProtocol import org.apache.kafka.common.utils.{LogContext, Time} import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec, NON_ROUTABLE_ADDRESS, UnknownAddressSpec} -import org.apache.kafka.raft.{FileBasedStateStore, KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog} +import org.apache.kafka.raft.{FileBasedStateStore, KafkaNetworkChannel, KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog} import org.apache.kafka.server.common.serialization.RecordSerde import org.apache.kafka.server.util.{KafkaScheduler, ShutdownableThread} import org.apache.kafka.server.fault.FaultHandler diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala deleted file mode 100644 index af230f66553..00000000000 --- a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala +++ /dev/null @@ -1,316 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package kafka.raft - -import java.net.InetSocketAddress -import java.util -import java.util.Collections -import org.apache.kafka.clients.MockClient.MockMetadataUpdater -import org.apache.kafka.clients.{MockClient, NodeApiVersions} -import org.apache.kafka.common.message.FetchRequestData.ReplicaState -import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData} -import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} -import org.apache.kafka.common.requests.{AbstractResponse, ApiVersionsResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchRequest, FetchResponse, VoteRequest, VoteResponse} -import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource -import org.apache.kafka.common.utils.{MockTime, Time} -import org.apache.kafka.common.{Node, TopicPartition, Uuid} -import org.apache.kafka.raft.RaftConfig.InetAddressSpec -import org.apache.kafka.raft.{RaftRequest, RaftUtil} -import org.junit.jupiter.api.Assertions._ -import org.junit.jupiter.api.{BeforeEach, Test} -import org.junit.jupiter.params.ParameterizedTest - -import scala.jdk.CollectionConverters._ - -class KafkaNetworkChannelTest { - import KafkaNetworkChannelTest._ - - private val clusterId = "clusterId" - private val requestTimeoutMs = 30000 - private val time = new MockTime() - private val client = new MockClient(time, new StubMetadataUpdater) - private val topicPartition = new TopicPartition("topic", 0) - private val topicId = Uuid.randomUuid() - private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, threadNamePrefix = "test-raft") - - @BeforeEach - def setupSupportedApis(): Unit = { - val supportedApis = RaftApis.map(ApiVersionsResponse.toApiVersion) - client.setNodeApiVersions(NodeApiVersions.create(supportedApis.asJava)) - } - - @Test - def testSendToUnknownDestination(): Unit = { - val destinationId = 2 - assertBrokerNotAvailable(destinationId) - } - - @Test - def testSendToBlackedOutDestination(): Unit = { - val destinationId = 2 - val destinationNode = new Node(destinationId, "127.0.0.1", 9092) - channel.updateEndpoint(destinationId, new InetAddressSpec( - new InetSocketAddress(destinationNode.host, destinationNode.port))) - client.backoff(destinationNode, 500) - assertBrokerNotAvailable(destinationId) - } - - @Test - def testWakeupClientOnSend(): Unit = { - val destinationId = 2 - val destinationNode = new Node(destinationId, "127.0.0.1", 9092) - channel.updateEndpoint(destinationId, new InetAddressSpec( - new InetSocketAddress(destinationNode.host, destinationNode.port))) - - client.enableBlockingUntilWakeup(1) - - val ioThread = new Thread() { - override def run(): Unit = { - // Block in poll until we get the expected wakeup - channel.pollOnce() - - // Poll a second time to send request and receive response - channel.pollOnce() - } - } - - val response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST)) - client.prepareResponseFrom(response, destinationNode, false) - - ioThread.start() - val request 
= sendTestRequest(ApiKeys.FETCH, destinationId) - - ioThread.join() - assertResponseCompleted(request, Errors.INVALID_REQUEST) - } - - @Test - def testSendAndDisconnect(): Unit = { - val destinationId = 2 - val destinationNode = new Node(destinationId, "127.0.0.1", 9092) - channel.updateEndpoint(destinationId, new InetAddressSpec( - new InetSocketAddress(destinationNode.host, destinationNode.port))) - - for (apiKey <- RaftApis) { - val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)) - client.prepareResponseFrom(response, destinationNode, true) - sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) - } - } - - @Test - def testSendAndFailAuthentication(): Unit = { - val destinationId = 2 - val destinationNode = new Node(destinationId, "127.0.0.1", 9092) - channel.updateEndpoint(destinationId, new InetAddressSpec( - new InetSocketAddress(destinationNode.host, destinationNode.port))) - - for (apiKey <- RaftApis) { - client.createPendingAuthenticationError(destinationNode, 100) - sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION) - - // reset to clear backoff time - client.reset() - } - } - - private def assertBrokerNotAvailable(destinationId: Int): Unit = { - for (apiKey <- RaftApis) { - sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) - } - } - - @Test - def testSendAndReceiveOutboundRequest(): Unit = { - val destinationId = 2 - val destinationNode = new Node(destinationId, "127.0.0.1", 9092) - channel.updateEndpoint(destinationId, new InetAddressSpec( - new InetSocketAddress(destinationNode.host, destinationNode.port))) - - for (apiKey <- RaftApis) { - val expectedError = Errors.INVALID_REQUEST - val response = buildResponse(buildTestErrorResponse(apiKey, expectedError)) - client.prepareResponseFrom(response, destinationNode) - sendAndAssertErrorResponse(apiKey, destinationId, expectedError) - } - } - - @Test - def testUnsupportedVersionError(): Unit = { - val destinationId = 2 - val destinationNode = new Node(destinationId, "127.0.0.1", 9092) - channel.updateEndpoint(destinationId, new InetAddressSpec( - new InetSocketAddress(destinationNode.host, destinationNode.port))) - - for (apiKey <- RaftApis) { - client.prepareUnsupportedVersionResponse(request => request.apiKey == apiKey) - sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION) - } - } - - @ParameterizedTest - @ApiKeyVersionsSource(apiKey = ApiKeys.FETCH) - def testFetchRequestDowngrade(version: Short): Unit = { - val destinationId = 2 - val destinationNode = new Node(destinationId, "127.0.0.1", 9092) - channel.updateEndpoint(destinationId, new InetAddressSpec( - new InetSocketAddress(destinationNode.host, destinationNode.port))) - sendTestRequest(ApiKeys.FETCH, destinationId) - channel.pollOnce() - - assertEquals(1, client.requests.size) - val request = client.requests.peek.requestBuilder.build(version) - - if (version < 15) { - assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == 1) - assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == -1) - } else { - assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == -1) - assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == 1) - } - } - - private def sendTestRequest( - apiKey: ApiKeys, - destinationId: Int, - ): RaftRequest.Outbound = { - val correlationId = channel.newCorrelationId() - val createdTimeMs = time.milliseconds() - val apiRequest = buildTestRequest(apiKey) - val request = new 
RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs) - channel.send(request) - request - } - - private def assertResponseCompleted( - request: RaftRequest.Outbound, - expectedError: Errors - ): Unit = { - assertTrue(request.completion.isDone) - - val response = request.completion.get() - assertEquals(request.destinationId, response.sourceId) - assertEquals(request.correlationId, response.correlationId) - assertEquals(request.data.apiKey, response.data.apiKey) - assertEquals(expectedError, extractError(response.data)) - } - - private def sendAndAssertErrorResponse( - apiKey: ApiKeys, - destinationId: Int, - error: Errors - ): Unit = { - val request = sendTestRequest(apiKey, destinationId) - channel.pollOnce() - assertResponseCompleted(request, error) - } - - private def buildTestRequest(key: ApiKeys): ApiMessage = { - val leaderEpoch = 5 - val leaderId = 1 - key match { - case ApiKeys.BEGIN_QUORUM_EPOCH => - BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId) - - case ApiKeys.END_QUORUM_EPOCH => - EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, - leaderEpoch, Collections.singletonList(2)) - - case ApiKeys.VOTE => - val lastEpoch = 4 - VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329) - - case ApiKeys.FETCH => - val request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition => { - fetchPartition - .setCurrentLeaderEpoch(5) - .setFetchOffset(333) - .setLastFetchedEpoch(5) - }) - request.setReplicaState(new ReplicaState().setReplicaId(1)) - - case _ => - throw new AssertionError(s"Unexpected api $key") - } - } - - private def buildTestErrorResponse(key: ApiKeys, error: Errors): ApiMessage = { - key match { - case ApiKeys.BEGIN_QUORUM_EPOCH => - new BeginQuorumEpochResponseData() - .setErrorCode(error.code) - - case ApiKeys.END_QUORUM_EPOCH => - new EndQuorumEpochResponseData() - .setErrorCode(error.code) - - case ApiKeys.VOTE => - VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); - - case ApiKeys.FETCH => - new FetchResponseData() - .setErrorCode(error.code) - - case _ => - throw new AssertionError(s"Unexpected api $key") - } - } - - private def extractError(response: ApiMessage): Errors = { - val code = (response: @unchecked) match { - case res: BeginQuorumEpochResponseData => res.errorCode - case res: EndQuorumEpochResponseData => res.errorCode - case res: FetchResponseData => res.errorCode - case res: VoteResponseData => res.errorCode - } - Errors.forCode(code) - } - - - def buildResponse(responseData: ApiMessage): AbstractResponse = { - responseData match { - case voteResponse: VoteResponseData => - new VoteResponse(voteResponse) - case beginEpochResponse: BeginQuorumEpochResponseData => - new BeginQuorumEpochResponse(beginEpochResponse) - case endEpochResponse: EndQuorumEpochResponseData => - new EndQuorumEpochResponse(endEpochResponse) - case fetchResponse: FetchResponseData => - new FetchResponse(fetchResponse) - case _ => - throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData") - } - } - -} - -object KafkaNetworkChannelTest { - val RaftApis = Seq( - ApiKeys.VOTE, - ApiKeys.BEGIN_QUORUM_EPOCH, - ApiKeys.END_QUORUM_EPOCH, - ApiKeys.FETCH, - ) - - private class StubMetadataUpdater extends MockMetadataUpdater { - override def fetchNodes(): util.List[Node] = Collections.emptyList() - - override def isUpdateNeeded: Boolean = false - - override def update(time: Time, update: 
MockClient.MetadataUpdate): Unit = {} - } -}
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
new file mode 100644
index 00000000000..2c0dd25d439
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import org.apache.kafka.clients.ClientResponse;
+import org.apache.kafka.clients.KafkaClient;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.message.BeginQuorumEpochRequestData;
+import org.apache.kafka.common.message.EndQuorumEpochRequestData;
+import org.apache.kafka.common.message.FetchRequestData;
+import org.apache.kafka.common.message.FetchSnapshotRequestData;
+import org.apache.kafka.common.message.VoteRequestData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
+import org.apache.kafka.common.requests.EndQuorumEpochRequest;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchSnapshotRequest;
+import org.apache.kafka.common.requests.VoteRequest;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.server.util.InterBrokerSendThread;
+import org.apache.kafka.server.util.RequestAndCompletionHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class KafkaNetworkChannel implements NetworkChannel {
+
+    static class SendThread extends InterBrokerSendThread {
+
+        private Queue<RequestAndCompletionHandler> queue = new ConcurrentLinkedQueue<>();
+
+        public SendThread(String name, KafkaClient networkClient, int requestTimeoutMs, Time time, boolean isInterruptible) {
+            super(name, networkClient, requestTimeoutMs, time, isInterruptible);
+        }
+
+        @Override
+        public Collection<RequestAndCompletionHandler> generateRequests() {
+            List<RequestAndCompletionHandler> list = new ArrayList<>();
+            while (true) {
+                RequestAndCompletionHandler request = queue.poll();
+                if (request == null) {
+                    return list;
+                } else {
+                    list.add(request);
+                }
+            }
+        }
+
+        public void sendRequest(RequestAndCompletionHandler request) {
+            queue.add(request);
+            wakeup();
+        }
+    }
+
+    private static final Logger log = LoggerFactory.getLogger(KafkaNetworkChannel.class);
+
+    private final SendThread requestThread;
+
+    private final AtomicInteger correlationIdCounter = new AtomicInteger(0);
+    private final Map<Integer, Node> endpoints = new HashMap<>();
+
+    public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) {
+        this.requestThread = new SendThread(
+            threadNamePrefix + "-outbound-request-thread",
+            client,
+            requestTimeoutMs,
+            time,
+            false
+        );
+    }
+
+    @Override
+    public int newCorrelationId() {
+        return correlationIdCounter.getAndIncrement();
+    }
+
+    @Override
+    public void send(RaftRequest.Outbound request) {
+        Node node = endpoints.get(request.destinationId());
+        if (node != null) {
+            requestThread.sendRequest(new RequestAndCompletionHandler(
+                request.createdTimeMs,
+                node,
+                buildRequest(request.data),
+                response -> sendOnComplete(request, response)
+            ));
+        } else
+            sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE));
+    }
+
+    private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) {
+        RaftResponse.Inbound response = new RaftResponse.Inbound(
+            request.correlationId,
+            message,
+            request.destinationId()
+        );
+        request.completion.complete(response);
+    }
+
+    private void sendOnComplete(RaftRequest.Outbound request, ClientResponse clientResponse) {
+        ApiMessage response;
+        if (clientResponse.versionMismatch() != null) {
+            log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch());
+            response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION);
+        } else if (clientResponse.authenticationException() != null) {
+            // For now we treat authentication errors as retriable. We use the
+            // `NETWORK_EXCEPTION` error code for lack of a good alternative.
+            // Note that `NodeToControllerChannelManager` will still log the
+            // authentication errors so that users have a chance to fix the problem.
+            log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException());
+            response = errorResponse(request.data, Errors.NETWORK_EXCEPTION);
+        } else if (clientResponse.wasDisconnected()) {
+            response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE);
+        } else {
+            response = clientResponse.responseBody().data();
+        }
+        sendCompleteFuture(request, response);
+    }
+
+    private ApiMessage errorResponse(ApiMessage request, Errors error) {
+        ApiKeys apiKey = ApiKeys.forId(request.apiKey());
+        return RaftUtil.errorResponse(apiKey, error);
+    }
+
+    @Override
+    public void updateEndpoint(int id, RaftConfig.InetAddressSpec spec) {
+        Node node = new Node(id, spec.address.getHostString(), spec.address.getPort());
+        endpoints.put(id, node);
+    }
+
+    public void start() {
+        requestThread.start();
+    }
+
+    @Override
+    public void close() throws InterruptedException {
+        requestThread.shutdown();
+    }
+
+    // Visible for testing
+    public void pollOnce() {
+        requestThread.doWork();
+    }
+
+    static AbstractRequest.Builder<? extends AbstractRequest> buildRequest(ApiMessage requestData) {
+        if (requestData instanceof VoteRequestData)
+            return new VoteRequest.Builder((VoteRequestData) requestData);
+        if (requestData instanceof BeginQuorumEpochRequestData)
+            return new BeginQuorumEpochRequest.Builder((BeginQuorumEpochRequestData) requestData);
+        if (requestData instanceof EndQuorumEpochRequestData)
+            return new EndQuorumEpochRequest.Builder((EndQuorumEpochRequestData) requestData);
+        if (requestData instanceof FetchRequestData)
+            return new FetchRequest.SimpleBuilder((FetchRequestData) requestData);
+        if (requestData instanceof FetchSnapshotRequestData)
+            return new FetchSnapshotRequest.Builder((FetchSnapshotRequestData) requestData);
+        throw new IllegalArgumentException("Unexpected type for requestData: " + requestData);
+    }
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
index e3482e56751..e527adf6f9b 100644
--- a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
+++ b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
@@ -16,13 +16,11 @@
  */
 package org.apache.kafka.raft;
 
-import java.io.Closeable;
-
 /**
  * A simple network interface with few assumptions. We do not assume ordering
  * of requests or even that every outbound request will receive a response.
  */
-public interface NetworkChannel extends Closeable {
+public interface NetworkChannel extends AutoCloseable {
 
     /**
      * Generate a new and unique correlationId for a new request to be sent.
@@ -41,6 +39,6 @@ public interface NetworkChannel extends Closeable {
      */
     void updateEndpoint(int id, RaftConfig.InetAddressSpec address);
 
-    default void close() {}
+    default void close() throws InterruptedException {}
 
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java
new file mode 100644
index 00000000000..3a1d097fc7a
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import org.apache.kafka.clients.MockClient;
+import org.apache.kafka.clients.NodeApiVersions;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.message.ApiVersionsResponseData;
+import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
+import org.apache.kafka.common.message.EndQuorumEpochResponseData;
+import org.apache.kafka.common.message.FetchRequestData;
+import org.apache.kafka.common.message.FetchResponseData;
+import org.apache.kafka.common.message.VoteResponseData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.AbstractResponse;
+import org.apache.kafka.common.requests.ApiVersionsResponse;
+import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
+import org.apache.kafka.common.requests.BeginQuorumEpochResponse;
+import org.apache.kafka.common.requests.EndQuorumEpochRequest;
+import org.apache.kafka.common.requests.EndQuorumEpochResponse;
+import org.apache.kafka.common.requests.FetchRequest;
+import org.apache.kafka.common.requests.FetchResponse;
+import org.apache.kafka.common.requests.VoteRequest;
+import org.apache.kafka.common.requests.VoteResponse;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+
+import java.net.InetSocketAddress;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class KafkaNetworkChannelTest {
+
+    private static class StubMetadataUpdater implements MockClient.MockMetadataUpdater {
+
+        @Override
+        public List<Node> fetchNodes() {
+            return Collections.emptyList();
+        }
+
+        @Override
+        public boolean isUpdateNeeded() {
+            return false;
+        }
+
+        @Override
+        public void update(Time time, MockClient.MetadataUpdate update) { }
+    }
+
+    private static final List<ApiKeys> RAFT_APIS = asList(
+        ApiKeys.VOTE,
+        ApiKeys.BEGIN_QUORUM_EPOCH,
+        ApiKeys.END_QUORUM_EPOCH,
+        ApiKeys.FETCH
+    );
+
+    private final String clusterId = "clusterId";
+    private final int requestTimeoutMs = 30000;
+    private final Time time = new MockTime();
+    private final MockClient client = new MockClient(time, new StubMetadataUpdater());
+    private final TopicPartition topicPartition = new TopicPartition("topic", 0);
+    private final Uuid topicId = Uuid.randomUuid();
+    private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft");
+
+    @BeforeEach
+    public void setupSupportedApis() {
+        List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS.stream().map(
+            ApiVersionsResponse::toApiVersion).collect(Collectors.toList());
+        client.setNodeApiVersions(NodeApiVersions.create(supportedApis));
+    }
+
+    @Test
+    public void testSendToUnknownDestination() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        assertBrokerNotAvailable(destinationId);
+    }
+
+    @Test
+    public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+        client.backoff(destinationNode, 500);
+        assertBrokerNotAvailable(destinationId);
+    }
+
+    @Test
+    public void testWakeupClientOnSend() throws InterruptedException, ExecutionException {
+        int destinationId = 2;
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+        client.enableBlockingUntilWakeup(1);
+
+        Thread ioThread = new Thread(() -> {
+            // Block in poll until we get the expected wakeup
+            channel.pollOnce();
+
+            // Poll a second time to send request and receive response
+            channel.pollOnce();
+        });
+
+        AbstractResponse response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST));
+        client.prepareResponseFrom(response, destinationNode, false);
+
+        ioThread.start();
+        RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId);
+
+        ioThread.join();
+        assertResponseCompleted(request, Errors.INVALID_REQUEST);
+    }
+
+    @Test
+    public void testSendAndDisconnect() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST));
+            client.prepareResponseFrom(response, destinationNode, true);
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
+        }
+    }
+
+    @Test
+    public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            client.createPendingAuthenticationError(destinationNode, 100);
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION);
+
+            // reset to clear backoff time
+            client.reset();
+        }
+    }
+
+    private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException {
+        for (ApiKeys apiKey : RAFT_APIS) {
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
+        }
+    }
+
+    @Test
+    public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            Errors expectedError = Errors.INVALID_REQUEST;
+            AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError));
+            client.prepareResponseFrom(response, destinationNode);
+            System.out.println("api key " + apiKey + ", response " + response);
+            sendAndAssertErrorResponse(apiKey, destinationId, expectedError);
+        }
+    }
+
+    @Test
+    public void testUnsupportedVersionError() throws ExecutionException, InterruptedException {
+        int destinationId = 2;
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+
+        for (ApiKeys apiKey : RAFT_APIS) {
+            client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey);
+            sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION);
+        }
+    }
+
+    @ParameterizedTest
+    @ApiKeyVersionsSource(apiKey = ApiKeys.FETCH)
+    public void testFetchRequestDowngrade(short version) {
+        int destinationId = 2;
+        Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
+        channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec(
+            new InetSocketAddress(destinationNode.host(), destinationNode.port())));
+        sendTestRequest(ApiKeys.FETCH, destinationId);
+        channel.pollOnce();
+
+        assertEquals(1, client.requests().size());
+        AbstractRequest request = client.requests().peek().requestBuilder().build(version);
+
+        if (version < 15) {
+            assertTrue(((FetchRequest) request).data().replicaId() == 1);
+            assertTrue(((FetchRequest) request).data().replicaState().replicaId() == -1);
+        } else {
+            assertTrue(((FetchRequest) request).data().replicaId() == -1);
+            assertTrue(((FetchRequest) request).data().replicaState().replicaId() == 1);
+        }
+    }
+
+    private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) {
+        int correlationId = channel.newCorrelationId();
+        long createdTimeMs = time.milliseconds();
+        ApiMessage apiRequest = buildTestRequest(apiKey);
+        RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs);
+        channel.send(request);
+        return request;
+    }
+
+    private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException {
+        assertTrue(request.completion.isDone());
+
+        RaftResponse.Inbound response = request.completion.get();
+        assertEquals(request.destinationId(), response.sourceId());
+        assertEquals(request.correlationId, response.correlationId);
+        assertEquals(request.data.apiKey(), response.data.apiKey());
+        assertEquals(expectedError, extractError(response.data));
+    }
+
+    private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException {
+        RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId);
+        channel.pollOnce();
+        assertResponseCompleted(request, error);
+    }
+
+    private ApiMessage buildTestRequest(ApiKeys key) {
+        int leaderEpoch = 5;
+        int leaderId = 1;
+        switch (key) {
+            case BEGIN_QUORUM_EPOCH:
+                return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId);
+            case END_QUORUM_EPOCH:
+                return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch,
+                    Collections.singletonList(2));
+            case VOTE:
+                int lastEpoch = 4;
+                return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329);
+            case FETCH:
+                FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> {
+                    fetchPartition
+                        .setCurrentLeaderEpoch(5)
+                        .setFetchOffset(333)
+                        .setLastFetchedEpoch(5);
+                });
+                request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1));
+                return request;
+            default:
+                throw new AssertionError("Unexpected api " + key);
+        }
+    }
+
+    private ApiMessage buildTestErrorResponse(ApiKeys key, Errors error) {
+        switch (key) {
+            case BEGIN_QUORUM_EPOCH:
+                return new BeginQuorumEpochResponseData().setErrorCode(error.code());
+            case END_QUORUM_EPOCH:
+                return new EndQuorumEpochResponseData().setErrorCode(error.code());
+            case VOTE:
+                return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
+            case FETCH:
+                return new FetchResponseData().setErrorCode(error.code());
+            default:
+                throw new AssertionError("Unexpected api " + key);
+        }
+    }
+
+    private Errors extractError(ApiMessage response) {
+        short code;
+        if (response instanceof BeginQuorumEpochResponseData)
+            code = ((BeginQuorumEpochResponseData) response).errorCode();
+        else if (response instanceof EndQuorumEpochResponseData)
+            code = ((EndQuorumEpochResponseData) response).errorCode();
+        else if (response instanceof FetchResponseData)
+            code = ((FetchResponseData) response).errorCode();
+        else if (response instanceof VoteResponseData)
+            code = ((VoteResponseData) response).errorCode();
+        else
+            throw new IllegalArgumentException("Unexpected type for responseData: " + response);
+        return Errors.forCode(code);
+    }
+
+    private AbstractResponse buildResponse(ApiMessage responseData) {
+        if (responseData instanceof VoteResponseData)
+            return new VoteResponse((VoteResponseData) responseData);
+        if (responseData instanceof BeginQuorumEpochResponseData)
+            return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData);
+        if (responseData instanceof EndQuorumEpochResponseData)
+            return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData);
+        if (responseData instanceof FetchResponseData)
+            return new FetchResponse((FetchResponseData) responseData);
+        throw new IllegalArgumentException("Unexpected type for responseData: " + responseData);
+    }
+}
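
For readers following the move, the sketch below (not part of the patch) shows the send path of the relocated Java KafkaNetworkChannel in isolation: an outbound request to a destination with no registered endpoint is completed immediately with a BROKER_NOT_AVAILABLE error body, which is exactly what testSendToUnknownDestination above verifies. It is a minimal, hypothetical example that sits in the same package as the new test so it can reuse the same field accessors, and it relies on the MockClient/MockTime test doubles rather than a production NetworkClient; the class name is invented for illustration.

// Hypothetical usage sketch, modeled on KafkaNetworkChannelTest above; not part of the patch.
package org.apache.kafka.raft;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.MockClient;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.BeginQuorumEpochRequest;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;

public class KafkaNetworkChannelUsageSketch {

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        MockTime time = new MockTime();
        // Same stub metadata updater shape as the test's StubMetadataUpdater.
        MockClient client = new MockClient(time, new MockClient.MockMetadataUpdater() {
            @Override
            public List<Node> fetchNodes() {
                return Collections.emptyList();
            }

            @Override
            public boolean isUpdateNeeded() {
                return false;
            }

            @Override
            public void update(Time time, MockClient.MetadataUpdate update) { }
        });

        KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, 30000, "sketch-raft");

        // An outbound request carries a correlation id, the request body and the destination id.
        ApiMessage data = BeginQuorumEpochRequest.singletonRequest(
            new TopicPartition("topic", 0), "clusterId", 5, 1);
        RaftRequest.Outbound request = new RaftRequest.Outbound(
            channel.newCorrelationId(), data, 2, time.milliseconds());

        // No endpoint was registered for destination 2 via updateEndpoint(), so send() completes
        // the future right away with an error response instead of enqueueing it on the send thread.
        channel.send(request);
        RaftResponse.Inbound response = request.completion.get();
        short errorCode = ((BeginQuorumEpochResponseData) response.data).errorCode();
        System.out.println("Error: " + Errors.forCode(errorCode)); // prints BROKER_NOT_AVAILABLE
        // In a real broker, RaftManager registers endpoints, calls start(), and closes the channel
        // on shutdown; responses are then delivered asynchronously by the outbound request thread.
    }
}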