Browse Source
This is now possible since `InterBrokerSend` was moved from `core` to `server-common`. Also rewrite/move `KafkaNetworkChannelTest`. The scala version of `KafkaNetworkChannelTest` passed with the changes here (before I deleted it). Reviewers: Justine Olshan <jolshan@confluent.io>, José Armando García Sancio <jsancio@users.noreply.github.com>pull/14569/head
Ismael Juma
11 months ago
committed by
GitHub
7 changed files with 510 additions and 512 deletions
@ -1,191 +0,0 @@
@@ -1,191 +0,0 @@
|
||||
/* |
||||
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||
* contributor license agreements. See the NOTICE file distributed with |
||||
* this work for additional information regarding copyright ownership. |
||||
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||
* (the "License"); you may not use this file except in compliance with |
||||
* the License. You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0 |
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package kafka.raft |
||||
|
||||
import kafka.utils.Logging |
||||
import org.apache.kafka.clients.{ClientResponse, KafkaClient} |
||||
import org.apache.kafka.common.Node |
||||
import org.apache.kafka.common.message._ |
||||
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} |
||||
import org.apache.kafka.common.requests._ |
||||
import org.apache.kafka.common.utils.Time |
||||
import org.apache.kafka.raft.RaftConfig.InetAddressSpec |
||||
import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil} |
||||
import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler} |
||||
|
||||
import java.util |
||||
import java.util.concurrent.ConcurrentLinkedQueue |
||||
import java.util.concurrent.atomic.AtomicInteger |
||||
import scala.collection.mutable |
||||
|
||||
object KafkaNetworkChannel { |
||||
|
||||
private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = { |
||||
requestData match { |
||||
case voteRequest: VoteRequestData => |
||||
new VoteRequest.Builder(voteRequest) |
||||
case beginEpochRequest: BeginQuorumEpochRequestData => |
||||
new BeginQuorumEpochRequest.Builder(beginEpochRequest) |
||||
case endEpochRequest: EndQuorumEpochRequestData => |
||||
new EndQuorumEpochRequest.Builder(endEpochRequest) |
||||
case fetchRequest: FetchRequestData => |
||||
new FetchRequest.SimpleBuilder(fetchRequest) |
||||
case fetchSnapshotRequest: FetchSnapshotRequestData => |
||||
new FetchSnapshotRequest.Builder(fetchSnapshotRequest) |
||||
case _ => |
||||
throw new IllegalArgumentException(s"Unexpected type for requestData: $requestData") |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
private[raft] class RaftSendThread( |
||||
name: String, |
||||
networkClient: KafkaClient, |
||||
requestTimeoutMs: Int, |
||||
time: Time, |
||||
isInterruptible: Boolean = true |
||||
) extends InterBrokerSendThread( |
||||
name, |
||||
networkClient, |
||||
requestTimeoutMs, |
||||
time, |
||||
isInterruptible |
||||
) { |
||||
private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() |
||||
|
||||
def generateRequests(): util.Collection[RequestAndCompletionHandler] = { |
||||
val list = new util.ArrayList[RequestAndCompletionHandler]() |
||||
while (true) { |
||||
val request = queue.poll() |
||||
if (request == null) { |
||||
return list |
||||
} else { |
||||
list.add(request) |
||||
} |
||||
} |
||||
list |
||||
} |
||||
|
||||
def sendRequest(request: RequestAndCompletionHandler): Unit = { |
||||
queue.add(request) |
||||
wakeup() |
||||
} |
||||
|
||||
} |
||||
|
||||
|
||||
class KafkaNetworkChannel( |
||||
time: Time, |
||||
client: KafkaClient, |
||||
requestTimeoutMs: Int, |
||||
threadNamePrefix: String |
||||
) extends NetworkChannel with Logging { |
||||
import KafkaNetworkChannel._ |
||||
|
||||
type ResponseHandler = AbstractResponse => Unit |
||||
|
||||
private val correlationIdCounter = new AtomicInteger(0) |
||||
private val endpoints = mutable.HashMap.empty[Int, Node] |
||||
|
||||
private val requestThread = new RaftSendThread( |
||||
name = threadNamePrefix + "-outbound-request-thread", |
||||
networkClient = client, |
||||
requestTimeoutMs = requestTimeoutMs, |
||||
time = time, |
||||
isInterruptible = false |
||||
) |
||||
|
||||
override def send(request: RaftRequest.Outbound): Unit = { |
||||
def completeFuture(message: ApiMessage): Unit = { |
||||
val response = new RaftResponse.Inbound( |
||||
request.correlationId, |
||||
message, |
||||
request.destinationId |
||||
) |
||||
request.completion.complete(response) |
||||
} |
||||
|
||||
def onComplete(clientResponse: ClientResponse): Unit = { |
||||
val response = if (clientResponse.versionMismatch != null) { |
||||
error(s"Request $request failed due to unsupported version error", |
||||
clientResponse.versionMismatch) |
||||
errorResponse(request.data, Errors.UNSUPPORTED_VERSION) |
||||
} else if (clientResponse.authenticationException != null) { |
||||
// For now we treat authentication errors as retriable. We use the |
||||
// `NETWORK_EXCEPTION` error code for lack of a good alternative. |
||||
// Note that `NodeToControllerChannelManager` will still log the |
||||
// authentication errors so that users have a chance to fix the problem. |
||||
error(s"Request $request failed due to authentication error", |
||||
clientResponse.authenticationException) |
||||
errorResponse(request.data, Errors.NETWORK_EXCEPTION) |
||||
} else if (clientResponse.wasDisconnected()) { |
||||
errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE) |
||||
} else { |
||||
clientResponse.responseBody.data |
||||
} |
||||
completeFuture(response) |
||||
} |
||||
|
||||
endpoints.get(request.destinationId) match { |
||||
case Some(node) => |
||||
requestThread.sendRequest(new RequestAndCompletionHandler( |
||||
request.createdTimeMs, |
||||
node, |
||||
buildRequest(request.data), |
||||
onComplete |
||||
)) |
||||
|
||||
case None => |
||||
completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)) |
||||
} |
||||
} |
||||
|
||||
// Visible for testing |
||||
private[raft] def pollOnce(): Unit = { |
||||
requestThread.doWork() |
||||
} |
||||
|
||||
override def newCorrelationId(): Int = { |
||||
correlationIdCounter.getAndIncrement() |
||||
} |
||||
|
||||
private def errorResponse( |
||||
request: ApiMessage, |
||||
error: Errors |
||||
): ApiMessage = { |
||||
val apiKey = ApiKeys.forId(request.apiKey) |
||||
RaftUtil.errorResponse(apiKey, error) |
||||
} |
||||
|
||||
override def updateEndpoint(id: Int, spec: InetAddressSpec): Unit = { |
||||
val node = new Node(id, spec.address.getHostString, spec.address.getPort) |
||||
endpoints.put(id, node) |
||||
} |
||||
|
||||
def start(): Unit = { |
||||
requestThread.start() |
||||
} |
||||
|
||||
def initiateShutdown(): Unit = { |
||||
requestThread.initiateShutdown() |
||||
} |
||||
|
||||
override def close(): Unit = { |
||||
requestThread.shutdown() |
||||
} |
||||
} |
@ -1,316 +0,0 @@
@@ -1,316 +0,0 @@
|
||||
/* |
||||
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||
* contributor license agreements. See the NOTICE file distributed with |
||||
* this work for additional information regarding copyright ownership. |
||||
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||
* (the "License"); you may not use this file except in compliance with |
||||
* the License. You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0 |
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package kafka.raft |
||||
|
||||
import java.net.InetSocketAddress |
||||
import java.util |
||||
import java.util.Collections |
||||
import org.apache.kafka.clients.MockClient.MockMetadataUpdater |
||||
import org.apache.kafka.clients.{MockClient, NodeApiVersions} |
||||
import org.apache.kafka.common.message.FetchRequestData.ReplicaState |
||||
import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData} |
||||
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} |
||||
import org.apache.kafka.common.requests.{AbstractResponse, ApiVersionsResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchRequest, FetchResponse, VoteRequest, VoteResponse} |
||||
import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource |
||||
import org.apache.kafka.common.utils.{MockTime, Time} |
||||
import org.apache.kafka.common.{Node, TopicPartition, Uuid} |
||||
import org.apache.kafka.raft.RaftConfig.InetAddressSpec |
||||
import org.apache.kafka.raft.{RaftRequest, RaftUtil} |
||||
import org.junit.jupiter.api.Assertions._ |
||||
import org.junit.jupiter.api.{BeforeEach, Test} |
||||
import org.junit.jupiter.params.ParameterizedTest |
||||
|
||||
import scala.jdk.CollectionConverters._ |
||||
|
||||
class KafkaNetworkChannelTest { |
||||
import KafkaNetworkChannelTest._ |
||||
|
||||
private val clusterId = "clusterId" |
||||
private val requestTimeoutMs = 30000 |
||||
private val time = new MockTime() |
||||
private val client = new MockClient(time, new StubMetadataUpdater) |
||||
private val topicPartition = new TopicPartition("topic", 0) |
||||
private val topicId = Uuid.randomUuid() |
||||
private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, threadNamePrefix = "test-raft") |
||||
|
||||
@BeforeEach |
||||
def setupSupportedApis(): Unit = { |
||||
val supportedApis = RaftApis.map(ApiVersionsResponse.toApiVersion) |
||||
client.setNodeApiVersions(NodeApiVersions.create(supportedApis.asJava)) |
||||
} |
||||
|
||||
@Test |
||||
def testSendToUnknownDestination(): Unit = { |
||||
val destinationId = 2 |
||||
assertBrokerNotAvailable(destinationId) |
||||
} |
||||
|
||||
@Test |
||||
def testSendToBlackedOutDestination(): Unit = { |
||||
val destinationId = 2 |
||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
||||
client.backoff(destinationNode, 500) |
||||
assertBrokerNotAvailable(destinationId) |
||||
} |
||||
|
||||
@Test |
||||
def testWakeupClientOnSend(): Unit = { |
||||
val destinationId = 2 |
||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
||||
|
||||
client.enableBlockingUntilWakeup(1) |
||||
|
||||
val ioThread = new Thread() { |
||||
override def run(): Unit = { |
||||
// Block in poll until we get the expected wakeup |
||||
channel.pollOnce() |
||||
|
||||
// Poll a second time to send request and receive response |
||||
channel.pollOnce() |
||||
} |
||||
} |
||||
|
||||
val response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST)) |
||||
client.prepareResponseFrom(response, destinationNode, false) |
||||
|
||||
ioThread.start() |
||||
val request = sendTestRequest(ApiKeys.FETCH, destinationId) |
||||
|
||||
ioThread.join() |
||||
assertResponseCompleted(request, Errors.INVALID_REQUEST) |
||||
} |
||||
|
||||
@Test |
||||
def testSendAndDisconnect(): Unit = { |
||||
val destinationId = 2 |
||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
||||
|
||||
for (apiKey <- RaftApis) { |
||||
val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)) |
||||
client.prepareResponseFrom(response, destinationNode, true) |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) |
||||
} |
||||
} |
||||
|
||||
@Test |
||||
def testSendAndFailAuthentication(): Unit = { |
||||
val destinationId = 2 |
||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
||||
|
||||
for (apiKey <- RaftApis) { |
||||
client.createPendingAuthenticationError(destinationNode, 100) |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION) |
||||
|
||||
// reset to clear backoff time |
||||
client.reset() |
||||
} |
||||
} |
||||
|
||||
private def assertBrokerNotAvailable(destinationId: Int): Unit = { |
||||
for (apiKey <- RaftApis) { |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) |
||||
} |
||||
} |
||||
|
||||
@Test |
||||
def testSendAndReceiveOutboundRequest(): Unit = { |
||||
val destinationId = 2 |
||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
||||
|
||||
for (apiKey <- RaftApis) { |
||||
val expectedError = Errors.INVALID_REQUEST |
||||
val response = buildResponse(buildTestErrorResponse(apiKey, expectedError)) |
||||
client.prepareResponseFrom(response, destinationNode) |
||||
sendAndAssertErrorResponse(apiKey, destinationId, expectedError) |
||||
} |
||||
} |
||||
|
||||
@Test |
||||
def testUnsupportedVersionError(): Unit = { |
||||
val destinationId = 2 |
||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
||||
|
||||
for (apiKey <- RaftApis) { |
||||
client.prepareUnsupportedVersionResponse(request => request.apiKey == apiKey) |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION) |
||||
} |
||||
} |
||||
|
||||
@ParameterizedTest |
||||
@ApiKeyVersionsSource(apiKey = ApiKeys.FETCH) |
||||
def testFetchRequestDowngrade(version: Short): Unit = { |
||||
val destinationId = 2 |
||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
||||
sendTestRequest(ApiKeys.FETCH, destinationId) |
||||
channel.pollOnce() |
||||
|
||||
assertEquals(1, client.requests.size) |
||||
val request = client.requests.peek.requestBuilder.build(version) |
||||
|
||||
if (version < 15) { |
||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == 1) |
||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == -1) |
||||
} else { |
||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == -1) |
||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == 1) |
||||
} |
||||
} |
||||
|
||||
private def sendTestRequest( |
||||
apiKey: ApiKeys, |
||||
destinationId: Int, |
||||
): RaftRequest.Outbound = { |
||||
val correlationId = channel.newCorrelationId() |
||||
val createdTimeMs = time.milliseconds() |
||||
val apiRequest = buildTestRequest(apiKey) |
||||
val request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs) |
||||
channel.send(request) |
||||
request |
||||
} |
||||
|
||||
private def assertResponseCompleted( |
||||
request: RaftRequest.Outbound, |
||||
expectedError: Errors |
||||
): Unit = { |
||||
assertTrue(request.completion.isDone) |
||||
|
||||
val response = request.completion.get() |
||||
assertEquals(request.destinationId, response.sourceId) |
||||
assertEquals(request.correlationId, response.correlationId) |
||||
assertEquals(request.data.apiKey, response.data.apiKey) |
||||
assertEquals(expectedError, extractError(response.data)) |
||||
} |
||||
|
||||
private def sendAndAssertErrorResponse( |
||||
apiKey: ApiKeys, |
||||
destinationId: Int, |
||||
error: Errors |
||||
): Unit = { |
||||
val request = sendTestRequest(apiKey, destinationId) |
||||
channel.pollOnce() |
||||
assertResponseCompleted(request, error) |
||||
} |
||||
|
||||
private def buildTestRequest(key: ApiKeys): ApiMessage = { |
||||
val leaderEpoch = 5 |
||||
val leaderId = 1 |
||||
key match { |
||||
case ApiKeys.BEGIN_QUORUM_EPOCH => |
||||
BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId) |
||||
|
||||
case ApiKeys.END_QUORUM_EPOCH => |
||||
EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, |
||||
leaderEpoch, Collections.singletonList(2)) |
||||
|
||||
case ApiKeys.VOTE => |
||||
val lastEpoch = 4 |
||||
VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329) |
||||
|
||||
case ApiKeys.FETCH => |
||||
val request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition => { |
||||
fetchPartition |
||||
.setCurrentLeaderEpoch(5) |
||||
.setFetchOffset(333) |
||||
.setLastFetchedEpoch(5) |
||||
}) |
||||
request.setReplicaState(new ReplicaState().setReplicaId(1)) |
||||
|
||||
case _ => |
||||
throw new AssertionError(s"Unexpected api $key") |
||||
} |
||||
} |
||||
|
||||
private def buildTestErrorResponse(key: ApiKeys, error: Errors): ApiMessage = { |
||||
key match { |
||||
case ApiKeys.BEGIN_QUORUM_EPOCH => |
||||
new BeginQuorumEpochResponseData() |
||||
.setErrorCode(error.code) |
||||
|
||||
case ApiKeys.END_QUORUM_EPOCH => |
||||
new EndQuorumEpochResponseData() |
||||
.setErrorCode(error.code) |
||||
|
||||
case ApiKeys.VOTE => |
||||
VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); |
||||
|
||||
case ApiKeys.FETCH => |
||||
new FetchResponseData() |
||||
.setErrorCode(error.code) |
||||
|
||||
case _ => |
||||
throw new AssertionError(s"Unexpected api $key") |
||||
} |
||||
} |
||||
|
||||
private def extractError(response: ApiMessage): Errors = { |
||||
val code = (response: @unchecked) match { |
||||
case res: BeginQuorumEpochResponseData => res.errorCode |
||||
case res: EndQuorumEpochResponseData => res.errorCode |
||||
case res: FetchResponseData => res.errorCode |
||||
case res: VoteResponseData => res.errorCode |
||||
} |
||||
Errors.forCode(code) |
||||
} |
||||
|
||||
|
||||
def buildResponse(responseData: ApiMessage): AbstractResponse = { |
||||
responseData match { |
||||
case voteResponse: VoteResponseData => |
||||
new VoteResponse(voteResponse) |
||||
case beginEpochResponse: BeginQuorumEpochResponseData => |
||||
new BeginQuorumEpochResponse(beginEpochResponse) |
||||
case endEpochResponse: EndQuorumEpochResponseData => |
||||
new EndQuorumEpochResponse(endEpochResponse) |
||||
case fetchResponse: FetchResponseData => |
||||
new FetchResponse(fetchResponse) |
||||
case _ => |
||||
throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData") |
||||
} |
||||
} |
||||
|
||||
} |
||||
|
||||
object KafkaNetworkChannelTest { |
||||
val RaftApis = Seq( |
||||
ApiKeys.VOTE, |
||||
ApiKeys.BEGIN_QUORUM_EPOCH, |
||||
ApiKeys.END_QUORUM_EPOCH, |
||||
ApiKeys.FETCH, |
||||
) |
||||
|
||||
private class StubMetadataUpdater extends MockMetadataUpdater { |
||||
override def fetchNodes(): util.List[Node] = Collections.emptyList() |
||||
|
||||
override def isUpdateNeeded: Boolean = false |
||||
|
||||
override def update(time: Time, update: MockClient.MetadataUpdate): Unit = {} |
||||
} |
||||
} |
@ -0,0 +1,183 @@
@@ -0,0 +1,183 @@
|
||||
/* |
||||
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||
* contributor license agreements. See the NOTICE file distributed with |
||||
* this work for additional information regarding copyright ownership. |
||||
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||
* (the "License"); you may not use this file except in compliance with |
||||
* the License. You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.apache.kafka.raft; |
||||
|
||||
import org.apache.kafka.clients.ClientResponse; |
||||
import org.apache.kafka.clients.KafkaClient; |
||||
import org.apache.kafka.common.Node; |
||||
import org.apache.kafka.common.message.BeginQuorumEpochRequestData; |
||||
import org.apache.kafka.common.message.EndQuorumEpochRequestData; |
||||
import org.apache.kafka.common.message.FetchRequestData; |
||||
import org.apache.kafka.common.message.FetchSnapshotRequestData; |
||||
import org.apache.kafka.common.message.VoteRequestData; |
||||
import org.apache.kafka.common.protocol.ApiKeys; |
||||
import org.apache.kafka.common.protocol.ApiMessage; |
||||
import org.apache.kafka.common.protocol.Errors; |
||||
import org.apache.kafka.common.requests.AbstractRequest; |
||||
import org.apache.kafka.common.requests.BeginQuorumEpochRequest; |
||||
import org.apache.kafka.common.requests.EndQuorumEpochRequest; |
||||
import org.apache.kafka.common.requests.FetchRequest; |
||||
import org.apache.kafka.common.requests.FetchSnapshotRequest; |
||||
import org.apache.kafka.common.requests.VoteRequest; |
||||
import org.apache.kafka.common.utils.Time; |
||||
import org.apache.kafka.server.util.InterBrokerSendThread; |
||||
import org.apache.kafka.server.util.RequestAndCompletionHandler; |
||||
import org.slf4j.Logger; |
||||
import org.slf4j.LoggerFactory; |
||||
|
||||
import java.util.ArrayList; |
||||
import java.util.Collection; |
||||
import java.util.HashMap; |
||||
import java.util.List; |
||||
import java.util.Map; |
||||
import java.util.Queue; |
||||
import java.util.concurrent.ConcurrentLinkedQueue; |
||||
import java.util.concurrent.atomic.AtomicInteger; |
||||
|
||||
public class KafkaNetworkChannel implements NetworkChannel { |
||||
|
||||
static class SendThread extends InterBrokerSendThread { |
||||
|
||||
private Queue<RequestAndCompletionHandler> queue = new ConcurrentLinkedQueue<>(); |
||||
|
||||
public SendThread(String name, KafkaClient networkClient, int requestTimeoutMs, Time time, boolean isInterruptible) { |
||||
super(name, networkClient, requestTimeoutMs, time, isInterruptible); |
||||
} |
||||
|
||||
@Override |
||||
public Collection<RequestAndCompletionHandler> generateRequests() { |
||||
List<RequestAndCompletionHandler> list = new ArrayList<>(); |
||||
while (true) { |
||||
RequestAndCompletionHandler request = queue.poll(); |
||||
if (request == null) { |
||||
return list; |
||||
} else { |
||||
list.add(request); |
||||
} |
||||
} |
||||
} |
||||
|
||||
public void sendRequest(RequestAndCompletionHandler request) { |
||||
queue.add(request); |
||||
wakeup(); |
||||
} |
||||
} |
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(KafkaNetworkChannel.class); |
||||
|
||||
private final SendThread requestThread; |
||||
|
||||
private final AtomicInteger correlationIdCounter = new AtomicInteger(0); |
||||
private final Map<Integer, Node> endpoints = new HashMap<>(); |
||||
|
||||
public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) { |
||||
this.requestThread = new SendThread( |
||||
threadNamePrefix + "-outbound-request-thread", |
||||
client, |
||||
requestTimeoutMs, |
||||
time, |
||||
false |
||||
); |
||||
} |
||||
|
||||
@Override |
||||
public int newCorrelationId() { |
||||
return correlationIdCounter.getAndIncrement(); |
||||
} |
||||
|
||||
@Override |
||||
public void send(RaftRequest.Outbound request) { |
||||
Node node = endpoints.get(request.destinationId()); |
||||
if (node != null) { |
||||
requestThread.sendRequest(new RequestAndCompletionHandler( |
||||
request.createdTimeMs, |
||||
node, |
||||
buildRequest(request.data), |
||||
response -> sendOnComplete(request, response) |
||||
)); |
||||
} else |
||||
sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)); |
||||
} |
||||
|
||||
private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) { |
||||
RaftResponse.Inbound response = new RaftResponse.Inbound( |
||||
request.correlationId, |
||||
message, |
||||
request.destinationId() |
||||
); |
||||
request.completion.complete(response); |
||||
} |
||||
|
||||
private void sendOnComplete(RaftRequest.Outbound request, ClientResponse clientResponse) { |
||||
ApiMessage response; |
||||
if (clientResponse.versionMismatch() != null) { |
||||
log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch()); |
||||
response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION); |
||||
} else if (clientResponse.authenticationException() != null) { |
||||
// For now we treat authentication errors as retriable. We use the
|
||||
// `NETWORK_EXCEPTION` error code for lack of a good alternative.
|
||||
// Note that `NodeToControllerChannelManager` will still log the
|
||||
// authentication errors so that users have a chance to fix the problem.
|
||||
log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException()); |
||||
response = errorResponse(request.data, Errors.NETWORK_EXCEPTION); |
||||
} else if (clientResponse.wasDisconnected()) { |
||||
response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE); |
||||
} else { |
||||
response = clientResponse.responseBody().data(); |
||||
} |
||||
sendCompleteFuture(request, response); |
||||
} |
||||
|
||||
private ApiMessage errorResponse(ApiMessage request, Errors error) { |
||||
ApiKeys apiKey = ApiKeys.forId(request.apiKey()); |
||||
return RaftUtil.errorResponse(apiKey, error); |
||||
} |
||||
|
||||
@Override |
||||
public void updateEndpoint(int id, RaftConfig.InetAddressSpec spec) { |
||||
Node node = new Node(id, spec.address.getHostString(), spec.address.getPort()); |
||||
endpoints.put(id, node); |
||||
} |
||||
|
||||
public void start() { |
||||
requestThread.start(); |
||||
} |
||||
|
||||
@Override |
||||
public void close() throws InterruptedException { |
||||
requestThread.shutdown(); |
||||
} |
||||
|
||||
// Visible for testing
|
||||
public void pollOnce() { |
||||
requestThread.doWork(); |
||||
} |
||||
|
||||
static AbstractRequest.Builder<? extends AbstractRequest> buildRequest(ApiMessage requestData) { |
||||
if (requestData instanceof VoteRequestData) |
||||
return new VoteRequest.Builder((VoteRequestData) requestData); |
||||
if (requestData instanceof BeginQuorumEpochRequestData) |
||||
return new BeginQuorumEpochRequest.Builder((BeginQuorumEpochRequestData) requestData); |
||||
if (requestData instanceof EndQuorumEpochRequestData) |
||||
return new EndQuorumEpochRequest.Builder((EndQuorumEpochRequestData) requestData); |
||||
if (requestData instanceof FetchRequestData) |
||||
return new FetchRequest.SimpleBuilder((FetchRequestData) requestData); |
||||
if (requestData instanceof FetchSnapshotRequestData) |
||||
return new FetchSnapshotRequest.Builder((FetchSnapshotRequestData) requestData); |
||||
throw new IllegalArgumentException("Unexpected type for requestData: " + requestData); |
||||
} |
||||
} |
@ -0,0 +1,323 @@
@@ -0,0 +1,323 @@
|
||||
/* |
||||
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||
* contributor license agreements. See the NOTICE file distributed with |
||||
* this work for additional information regarding copyright ownership. |
||||
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||
* (the "License"); you may not use this file except in compliance with |
||||
* the License. You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.apache.kafka.raft; |
||||
|
||||
import org.apache.kafka.clients.MockClient; |
||||
import org.apache.kafka.clients.NodeApiVersions; |
||||
import org.apache.kafka.common.Node; |
||||
import org.apache.kafka.common.TopicPartition; |
||||
import org.apache.kafka.common.Uuid; |
||||
import org.apache.kafka.common.message.ApiVersionsResponseData; |
||||
import org.apache.kafka.common.message.BeginQuorumEpochResponseData; |
||||
import org.apache.kafka.common.message.EndQuorumEpochResponseData; |
||||
import org.apache.kafka.common.message.FetchRequestData; |
||||
import org.apache.kafka.common.message.FetchResponseData; |
||||
import org.apache.kafka.common.message.VoteResponseData; |
||||
import org.apache.kafka.common.protocol.ApiKeys; |
||||
import org.apache.kafka.common.protocol.ApiMessage; |
||||
import org.apache.kafka.common.protocol.Errors; |
||||
import org.apache.kafka.common.requests.AbstractRequest; |
||||
import org.apache.kafka.common.requests.AbstractResponse; |
||||
import org.apache.kafka.common.requests.ApiVersionsResponse; |
||||
import org.apache.kafka.common.requests.BeginQuorumEpochRequest; |
||||
import org.apache.kafka.common.requests.BeginQuorumEpochResponse; |
||||
import org.apache.kafka.common.requests.EndQuorumEpochRequest; |
||||
import org.apache.kafka.common.requests.EndQuorumEpochResponse; |
||||
import org.apache.kafka.common.requests.FetchRequest; |
||||
import org.apache.kafka.common.requests.FetchResponse; |
||||
import org.apache.kafka.common.requests.VoteRequest; |
||||
import org.apache.kafka.common.requests.VoteResponse; |
||||
import org.apache.kafka.common.utils.MockTime; |
||||
import org.apache.kafka.common.utils.Time; |
||||
import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource; |
||||
import org.junit.jupiter.api.BeforeEach; |
||||
import org.junit.jupiter.api.Test; |
||||
import org.junit.jupiter.params.ParameterizedTest; |
||||
|
||||
import java.net.InetSocketAddress; |
||||
import java.util.Collections; |
||||
import java.util.List; |
||||
import java.util.concurrent.ExecutionException; |
||||
import java.util.stream.Collectors; |
||||
|
||||
import static java.util.Arrays.asList; |
||||
import static org.junit.jupiter.api.Assertions.assertEquals; |
||||
import static org.junit.jupiter.api.Assertions.assertTrue; |
||||
|
||||
public class KafkaNetworkChannelTest { |
||||
|
||||
private static class StubMetadataUpdater implements MockClient.MockMetadataUpdater { |
||||
|
||||
@Override |
||||
public List<Node> fetchNodes() { |
||||
return Collections.emptyList(); |
||||
} |
||||
|
||||
@Override |
||||
public boolean isUpdateNeeded() { |
||||
return false; |
||||
} |
||||
|
||||
@Override |
||||
public void update(Time time, MockClient.MetadataUpdate update) { } |
||||
} |
||||
|
||||
private static final List<ApiKeys> RAFT_APIS = asList( |
||||
ApiKeys.VOTE, |
||||
ApiKeys.BEGIN_QUORUM_EPOCH, |
||||
ApiKeys.END_QUORUM_EPOCH, |
||||
ApiKeys.FETCH |
||||
); |
||||
|
||||
private final String clusterId = "clusterId"; |
||||
private final int requestTimeoutMs = 30000; |
||||
private final Time time = new MockTime(); |
||||
private final MockClient client = new MockClient(time, new StubMetadataUpdater()); |
||||
private final TopicPartition topicPartition = new TopicPartition("topic", 0); |
||||
private final Uuid topicId = Uuid.randomUuid(); |
||||
private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft"); |
||||
|
||||
@BeforeEach |
||||
public void setupSupportedApis() { |
||||
List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS.stream().map( |
||||
ApiVersionsResponse::toApiVersion).collect(Collectors.toList()); |
||||
client.setNodeApiVersions(NodeApiVersions.create(supportedApis)); |
||||
} |
||||
|
||||
@Test |
||||
public void testSendToUnknownDestination() throws ExecutionException, InterruptedException { |
||||
int destinationId = 2; |
||||
assertBrokerNotAvailable(destinationId); |
||||
} |
||||
|
||||
@Test |
||||
public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException { |
||||
int destinationId = 2; |
||||
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||
client.backoff(destinationNode, 500); |
||||
assertBrokerNotAvailable(destinationId); |
||||
} |
||||
|
||||
@Test |
||||
public void testWakeupClientOnSend() throws InterruptedException, ExecutionException { |
||||
int destinationId = 2; |
||||
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||
|
||||
client.enableBlockingUntilWakeup(1); |
||||
|
||||
Thread ioThread = new Thread(() -> { |
||||
// Block in poll until we get the expected wakeup
|
||||
channel.pollOnce(); |
||||
|
||||
// Poll a second time to send request and receive response
|
||||
channel.pollOnce(); |
||||
}); |
||||
|
||||
AbstractResponse response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST)); |
||||
client.prepareResponseFrom(response, destinationNode, false); |
||||
|
||||
ioThread.start(); |
||||
RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId); |
||||
|
||||
ioThread.join(); |
||||
assertResponseCompleted(request, Errors.INVALID_REQUEST); |
||||
} |
||||
|
||||
@Test |
||||
public void testSendAndDisconnect() throws ExecutionException, InterruptedException { |
||||
int destinationId = 2; |
||||
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||
|
||||
for (ApiKeys apiKey : RAFT_APIS) { |
||||
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)); |
||||
client.prepareResponseFrom(response, destinationNode, true); |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); |
||||
} |
||||
} |
||||
|
||||
@Test |
||||
public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException { |
||||
int destinationId = 2; |
||||
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||
|
||||
for (ApiKeys apiKey : RAFT_APIS) { |
||||
client.createPendingAuthenticationError(destinationNode, 100); |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION); |
||||
|
||||
// reset to clear backoff time
|
||||
client.reset(); |
||||
} |
||||
} |
||||
|
||||
private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException { |
||||
for (ApiKeys apiKey : RAFT_APIS) { |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); |
||||
} |
||||
} |
||||
|
||||
@Test |
||||
public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException { |
||||
int destinationId = 2; |
||||
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||
|
||||
for (ApiKeys apiKey : RAFT_APIS) { |
||||
Errors expectedError = Errors.INVALID_REQUEST; |
||||
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError)); |
||||
client.prepareResponseFrom(response, destinationNode); |
||||
System.out.println("api key " + apiKey + ", response " + response); |
||||
sendAndAssertErrorResponse(apiKey, destinationId, expectedError); |
||||
} |
||||
} |
||||
|
||||
@Test |
||||
public void testUnsupportedVersionError() throws ExecutionException, InterruptedException { |
||||
int destinationId = 2; |
||||
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||
|
||||
for (ApiKeys apiKey : RAFT_APIS) { |
||||
client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey); |
||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION); |
||||
} |
||||
} |
||||
|
||||
@ParameterizedTest |
||||
@ApiKeyVersionsSource(apiKey = ApiKeys.FETCH) |
||||
public void testFetchRequestDowngrade(short version) { |
||||
int destinationId = 2; |
||||
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||
sendTestRequest(ApiKeys.FETCH, destinationId); |
||||
channel.pollOnce(); |
||||
|
||||
assertEquals(1, client.requests().size()); |
||||
AbstractRequest request = client.requests().peek().requestBuilder().build(version); |
||||
|
||||
if (version < 15) { |
||||
assertTrue(((FetchRequest) request).data().replicaId() == 1); |
||||
assertTrue(((FetchRequest) request).data().replicaState().replicaId() == -1); |
||||
} else { |
||||
assertTrue(((FetchRequest) request).data().replicaId() == -1); |
||||
assertTrue(((FetchRequest) request).data().replicaState().replicaId() == 1); |
||||
} |
||||
} |
||||
|
||||
private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) { |
||||
int correlationId = channel.newCorrelationId(); |
||||
long createdTimeMs = time.milliseconds(); |
||||
ApiMessage apiRequest = buildTestRequest(apiKey); |
||||
RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs); |
||||
channel.send(request); |
||||
return request; |
||||
} |
||||
|
||||
private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException { |
||||
assertTrue(request.completion.isDone()); |
||||
|
||||
RaftResponse.Inbound response = request.completion.get(); |
||||
assertEquals(request.destinationId(), response.sourceId()); |
||||
assertEquals(request.correlationId, response.correlationId); |
||||
assertEquals(request.data.apiKey(), response.data.apiKey()); |
||||
assertEquals(expectedError, extractError(response.data)); |
||||
} |
||||
|
||||
private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException { |
||||
RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId); |
||||
channel.pollOnce(); |
||||
assertResponseCompleted(request, error); |
||||
} |
||||
|
||||
private ApiMessage buildTestRequest(ApiKeys key) { |
||||
int leaderEpoch = 5; |
||||
int leaderId = 1; |
||||
switch (key) { |
||||
case BEGIN_QUORUM_EPOCH: |
||||
return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId); |
||||
case END_QUORUM_EPOCH: |
||||
return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch, |
||||
Collections.singletonList(2)); |
||||
case VOTE: |
||||
int lastEpoch = 4; |
||||
return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329); |
||||
case FETCH: |
||||
FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> { |
||||
fetchPartition |
||||
.setCurrentLeaderEpoch(5) |
||||
.setFetchOffset(333) |
||||
.setLastFetchedEpoch(5); |
||||
}); |
||||
request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1)); |
||||
return request; |
||||
default: |
||||
throw new AssertionError("Unexpected api " + key); |
||||
} |
||||
} |
||||
|
||||
private ApiMessage buildTestErrorResponse(ApiKeys key, Errors error) { |
||||
switch (key) { |
||||
case BEGIN_QUORUM_EPOCH: |
||||
return new BeginQuorumEpochResponseData().setErrorCode(error.code()); |
||||
case END_QUORUM_EPOCH: |
||||
return new EndQuorumEpochResponseData().setErrorCode(error.code()); |
||||
case VOTE: |
||||
return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); |
||||
case FETCH: |
||||
return new FetchResponseData().setErrorCode(error.code()); |
||||
default: |
||||
throw new AssertionError("Unexpected api " + key); |
||||
} |
||||
} |
||||
|
||||
private Errors extractError(ApiMessage response) { |
||||
short code; |
||||
if (response instanceof BeginQuorumEpochResponseData) |
||||
code = ((BeginQuorumEpochResponseData) response).errorCode(); |
||||
else if (response instanceof EndQuorumEpochResponseData) |
||||
code = ((EndQuorumEpochResponseData) response).errorCode(); |
||||
else if (response instanceof FetchResponseData) |
||||
code = ((FetchResponseData) response).errorCode(); |
||||
else if (response instanceof VoteResponseData) |
||||
code = ((VoteResponseData) response).errorCode(); |
||||
else |
||||
throw new IllegalArgumentException("Unexpected type for responseData: " + response); |
||||
return Errors.forCode(code); |
||||
} |
||||
|
||||
private AbstractResponse buildResponse(ApiMessage responseData) { |
||||
if (responseData instanceof VoteResponseData) |
||||
return new VoteResponse((VoteResponseData) responseData); |
||||
if (responseData instanceof BeginQuorumEpochResponseData) |
||||
return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData); |
||||
if (responseData instanceof EndQuorumEpochResponseData) |
||||
return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData); |
||||
if (responseData instanceof FetchResponseData) |
||||
return new FetchResponse((FetchResponseData) responseData); |
||||
throw new IllegalArgumentException("Unexpected type for responseData: " + responseData); |
||||
} |
||||
} |
Loading…
Reference in new issue