Browse Source
This is now possible since `InterBrokerSend` was moved from `core` to `server-common`. Also rewrite/move `KafkaNetworkChannelTest`. The scala version of `KafkaNetworkChannelTest` passed with the changes here (before I deleted it). Reviewers: Justine Olshan <jolshan@confluent.io>, José Armando García Sancio <jsancio@users.noreply.github.com>pull/14569/head
Ismael Juma
11 months ago
committed by
GitHub
7 changed files with 510 additions and 512 deletions
@ -1,191 +0,0 @@ |
|||||||
/* |
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more |
|
||||||
* contributor license agreements. See the NOTICE file distributed with |
|
||||||
* this work for additional information regarding copyright ownership. |
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0 |
|
||||||
* (the "License"); you may not use this file except in compliance with |
|
||||||
* the License. You may obtain a copy of the License at |
|
||||||
* |
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0 |
|
||||||
* |
|
||||||
* Unless required by applicable law or agreed to in writing, software |
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, |
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
||||||
* See the License for the specific language governing permissions and |
|
||||||
* limitations under the License. |
|
||||||
*/ |
|
||||||
package kafka.raft |
|
||||||
|
|
||||||
import kafka.utils.Logging |
|
||||||
import org.apache.kafka.clients.{ClientResponse, KafkaClient} |
|
||||||
import org.apache.kafka.common.Node |
|
||||||
import org.apache.kafka.common.message._ |
|
||||||
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} |
|
||||||
import org.apache.kafka.common.requests._ |
|
||||||
import org.apache.kafka.common.utils.Time |
|
||||||
import org.apache.kafka.raft.RaftConfig.InetAddressSpec |
|
||||||
import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil} |
|
||||||
import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler} |
|
||||||
|
|
||||||
import java.util |
|
||||||
import java.util.concurrent.ConcurrentLinkedQueue |
|
||||||
import java.util.concurrent.atomic.AtomicInteger |
|
||||||
import scala.collection.mutable |
|
||||||
|
|
||||||
object KafkaNetworkChannel { |
|
||||||
|
|
||||||
private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = { |
|
||||||
requestData match { |
|
||||||
case voteRequest: VoteRequestData => |
|
||||||
new VoteRequest.Builder(voteRequest) |
|
||||||
case beginEpochRequest: BeginQuorumEpochRequestData => |
|
||||||
new BeginQuorumEpochRequest.Builder(beginEpochRequest) |
|
||||||
case endEpochRequest: EndQuorumEpochRequestData => |
|
||||||
new EndQuorumEpochRequest.Builder(endEpochRequest) |
|
||||||
case fetchRequest: FetchRequestData => |
|
||||||
new FetchRequest.SimpleBuilder(fetchRequest) |
|
||||||
case fetchSnapshotRequest: FetchSnapshotRequestData => |
|
||||||
new FetchSnapshotRequest.Builder(fetchSnapshotRequest) |
|
||||||
case _ => |
|
||||||
throw new IllegalArgumentException(s"Unexpected type for requestData: $requestData") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
} |
|
||||||
|
|
||||||
private[raft] class RaftSendThread( |
|
||||||
name: String, |
|
||||||
networkClient: KafkaClient, |
|
||||||
requestTimeoutMs: Int, |
|
||||||
time: Time, |
|
||||||
isInterruptible: Boolean = true |
|
||||||
) extends InterBrokerSendThread( |
|
||||||
name, |
|
||||||
networkClient, |
|
||||||
requestTimeoutMs, |
|
||||||
time, |
|
||||||
isInterruptible |
|
||||||
) { |
|
||||||
private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() |
|
||||||
|
|
||||||
def generateRequests(): util.Collection[RequestAndCompletionHandler] = { |
|
||||||
val list = new util.ArrayList[RequestAndCompletionHandler]() |
|
||||||
while (true) { |
|
||||||
val request = queue.poll() |
|
||||||
if (request == null) { |
|
||||||
return list |
|
||||||
} else { |
|
||||||
list.add(request) |
|
||||||
} |
|
||||||
} |
|
||||||
list |
|
||||||
} |
|
||||||
|
|
||||||
def sendRequest(request: RequestAndCompletionHandler): Unit = { |
|
||||||
queue.add(request) |
|
||||||
wakeup() |
|
||||||
} |
|
||||||
|
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
class KafkaNetworkChannel( |
|
||||||
time: Time, |
|
||||||
client: KafkaClient, |
|
||||||
requestTimeoutMs: Int, |
|
||||||
threadNamePrefix: String |
|
||||||
) extends NetworkChannel with Logging { |
|
||||||
import KafkaNetworkChannel._ |
|
||||||
|
|
||||||
type ResponseHandler = AbstractResponse => Unit |
|
||||||
|
|
||||||
private val correlationIdCounter = new AtomicInteger(0) |
|
||||||
private val endpoints = mutable.HashMap.empty[Int, Node] |
|
||||||
|
|
||||||
private val requestThread = new RaftSendThread( |
|
||||||
name = threadNamePrefix + "-outbound-request-thread", |
|
||||||
networkClient = client, |
|
||||||
requestTimeoutMs = requestTimeoutMs, |
|
||||||
time = time, |
|
||||||
isInterruptible = false |
|
||||||
) |
|
||||||
|
|
||||||
override def send(request: RaftRequest.Outbound): Unit = { |
|
||||||
def completeFuture(message: ApiMessage): Unit = { |
|
||||||
val response = new RaftResponse.Inbound( |
|
||||||
request.correlationId, |
|
||||||
message, |
|
||||||
request.destinationId |
|
||||||
) |
|
||||||
request.completion.complete(response) |
|
||||||
} |
|
||||||
|
|
||||||
def onComplete(clientResponse: ClientResponse): Unit = { |
|
||||||
val response = if (clientResponse.versionMismatch != null) { |
|
||||||
error(s"Request $request failed due to unsupported version error", |
|
||||||
clientResponse.versionMismatch) |
|
||||||
errorResponse(request.data, Errors.UNSUPPORTED_VERSION) |
|
||||||
} else if (clientResponse.authenticationException != null) { |
|
||||||
// For now we treat authentication errors as retriable. We use the |
|
||||||
// `NETWORK_EXCEPTION` error code for lack of a good alternative. |
|
||||||
// Note that `NodeToControllerChannelManager` will still log the |
|
||||||
// authentication errors so that users have a chance to fix the problem. |
|
||||||
error(s"Request $request failed due to authentication error", |
|
||||||
clientResponse.authenticationException) |
|
||||||
errorResponse(request.data, Errors.NETWORK_EXCEPTION) |
|
||||||
} else if (clientResponse.wasDisconnected()) { |
|
||||||
errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE) |
|
||||||
} else { |
|
||||||
clientResponse.responseBody.data |
|
||||||
} |
|
||||||
completeFuture(response) |
|
||||||
} |
|
||||||
|
|
||||||
endpoints.get(request.destinationId) match { |
|
||||||
case Some(node) => |
|
||||||
requestThread.sendRequest(new RequestAndCompletionHandler( |
|
||||||
request.createdTimeMs, |
|
||||||
node, |
|
||||||
buildRequest(request.data), |
|
||||||
onComplete |
|
||||||
)) |
|
||||||
|
|
||||||
case None => |
|
||||||
completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// Visible for testing |
|
||||||
private[raft] def pollOnce(): Unit = { |
|
||||||
requestThread.doWork() |
|
||||||
} |
|
||||||
|
|
||||||
override def newCorrelationId(): Int = { |
|
||||||
correlationIdCounter.getAndIncrement() |
|
||||||
} |
|
||||||
|
|
||||||
private def errorResponse( |
|
||||||
request: ApiMessage, |
|
||||||
error: Errors |
|
||||||
): ApiMessage = { |
|
||||||
val apiKey = ApiKeys.forId(request.apiKey) |
|
||||||
RaftUtil.errorResponse(apiKey, error) |
|
||||||
} |
|
||||||
|
|
||||||
override def updateEndpoint(id: Int, spec: InetAddressSpec): Unit = { |
|
||||||
val node = new Node(id, spec.address.getHostString, spec.address.getPort) |
|
||||||
endpoints.put(id, node) |
|
||||||
} |
|
||||||
|
|
||||||
def start(): Unit = { |
|
||||||
requestThread.start() |
|
||||||
} |
|
||||||
|
|
||||||
def initiateShutdown(): Unit = { |
|
||||||
requestThread.initiateShutdown() |
|
||||||
} |
|
||||||
|
|
||||||
override def close(): Unit = { |
|
||||||
requestThread.shutdown() |
|
||||||
} |
|
||||||
} |
|
@ -1,316 +0,0 @@ |
|||||||
/* |
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more |
|
||||||
* contributor license agreements. See the NOTICE file distributed with |
|
||||||
* this work for additional information regarding copyright ownership. |
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0 |
|
||||||
* (the "License"); you may not use this file except in compliance with |
|
||||||
* the License. You may obtain a copy of the License at |
|
||||||
* |
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0 |
|
||||||
* |
|
||||||
* Unless required by applicable law or agreed to in writing, software |
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, |
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
||||||
* See the License for the specific language governing permissions and |
|
||||||
* limitations under the License. |
|
||||||
*/ |
|
||||||
package kafka.raft |
|
||||||
|
|
||||||
import java.net.InetSocketAddress |
|
||||||
import java.util |
|
||||||
import java.util.Collections |
|
||||||
import org.apache.kafka.clients.MockClient.MockMetadataUpdater |
|
||||||
import org.apache.kafka.clients.{MockClient, NodeApiVersions} |
|
||||||
import org.apache.kafka.common.message.FetchRequestData.ReplicaState |
|
||||||
import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData} |
|
||||||
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} |
|
||||||
import org.apache.kafka.common.requests.{AbstractResponse, ApiVersionsResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchRequest, FetchResponse, VoteRequest, VoteResponse} |
|
||||||
import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource |
|
||||||
import org.apache.kafka.common.utils.{MockTime, Time} |
|
||||||
import org.apache.kafka.common.{Node, TopicPartition, Uuid} |
|
||||||
import org.apache.kafka.raft.RaftConfig.InetAddressSpec |
|
||||||
import org.apache.kafka.raft.{RaftRequest, RaftUtil} |
|
||||||
import org.junit.jupiter.api.Assertions._ |
|
||||||
import org.junit.jupiter.api.{BeforeEach, Test} |
|
||||||
import org.junit.jupiter.params.ParameterizedTest |
|
||||||
|
|
||||||
import scala.jdk.CollectionConverters._ |
|
||||||
|
|
||||||
class KafkaNetworkChannelTest { |
|
||||||
import KafkaNetworkChannelTest._ |
|
||||||
|
|
||||||
private val clusterId = "clusterId" |
|
||||||
private val requestTimeoutMs = 30000 |
|
||||||
private val time = new MockTime() |
|
||||||
private val client = new MockClient(time, new StubMetadataUpdater) |
|
||||||
private val topicPartition = new TopicPartition("topic", 0) |
|
||||||
private val topicId = Uuid.randomUuid() |
|
||||||
private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, threadNamePrefix = "test-raft") |
|
||||||
|
|
||||||
@BeforeEach |
|
||||||
def setupSupportedApis(): Unit = { |
|
||||||
val supportedApis = RaftApis.map(ApiVersionsResponse.toApiVersion) |
|
||||||
client.setNodeApiVersions(NodeApiVersions.create(supportedApis.asJava)) |
|
||||||
} |
|
||||||
|
|
||||||
@Test |
|
||||||
def testSendToUnknownDestination(): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
assertBrokerNotAvailable(destinationId) |
|
||||||
} |
|
||||||
|
|
||||||
@Test |
|
||||||
def testSendToBlackedOutDestination(): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
|
||||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
|
||||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
|
||||||
client.backoff(destinationNode, 500) |
|
||||||
assertBrokerNotAvailable(destinationId) |
|
||||||
} |
|
||||||
|
|
||||||
@Test |
|
||||||
def testWakeupClientOnSend(): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
|
||||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
|
||||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
|
||||||
|
|
||||||
client.enableBlockingUntilWakeup(1) |
|
||||||
|
|
||||||
val ioThread = new Thread() { |
|
||||||
override def run(): Unit = { |
|
||||||
// Block in poll until we get the expected wakeup |
|
||||||
channel.pollOnce() |
|
||||||
|
|
||||||
// Poll a second time to send request and receive response |
|
||||||
channel.pollOnce() |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
val response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST)) |
|
||||||
client.prepareResponseFrom(response, destinationNode, false) |
|
||||||
|
|
||||||
ioThread.start() |
|
||||||
val request = sendTestRequest(ApiKeys.FETCH, destinationId) |
|
||||||
|
|
||||||
ioThread.join() |
|
||||||
assertResponseCompleted(request, Errors.INVALID_REQUEST) |
|
||||||
} |
|
||||||
|
|
||||||
@Test |
|
||||||
def testSendAndDisconnect(): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
|
||||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
|
||||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
|
||||||
|
|
||||||
for (apiKey <- RaftApis) { |
|
||||||
val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)) |
|
||||||
client.prepareResponseFrom(response, destinationNode, true) |
|
||||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
@Test |
|
||||||
def testSendAndFailAuthentication(): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
|
||||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
|
||||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
|
||||||
|
|
||||||
for (apiKey <- RaftApis) { |
|
||||||
client.createPendingAuthenticationError(destinationNode, 100) |
|
||||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION) |
|
||||||
|
|
||||||
// reset to clear backoff time |
|
||||||
client.reset() |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
private def assertBrokerNotAvailable(destinationId: Int): Unit = { |
|
||||||
for (apiKey <- RaftApis) { |
|
||||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
@Test |
|
||||||
def testSendAndReceiveOutboundRequest(): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
|
||||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
|
||||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
|
||||||
|
|
||||||
for (apiKey <- RaftApis) { |
|
||||||
val expectedError = Errors.INVALID_REQUEST |
|
||||||
val response = buildResponse(buildTestErrorResponse(apiKey, expectedError)) |
|
||||||
client.prepareResponseFrom(response, destinationNode) |
|
||||||
sendAndAssertErrorResponse(apiKey, destinationId, expectedError) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
@Test |
|
||||||
def testUnsupportedVersionError(): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
|
||||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
|
||||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
|
||||||
|
|
||||||
for (apiKey <- RaftApis) { |
|
||||||
client.prepareUnsupportedVersionResponse(request => request.apiKey == apiKey) |
|
||||||
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
@ParameterizedTest |
|
||||||
@ApiKeyVersionsSource(apiKey = ApiKeys.FETCH) |
|
||||||
def testFetchRequestDowngrade(version: Short): Unit = { |
|
||||||
val destinationId = 2 |
|
||||||
val destinationNode = new Node(destinationId, "127.0.0.1", 9092) |
|
||||||
channel.updateEndpoint(destinationId, new InetAddressSpec( |
|
||||||
new InetSocketAddress(destinationNode.host, destinationNode.port))) |
|
||||||
sendTestRequest(ApiKeys.FETCH, destinationId) |
|
||||||
channel.pollOnce() |
|
||||||
|
|
||||||
assertEquals(1, client.requests.size) |
|
||||||
val request = client.requests.peek.requestBuilder.build(version) |
|
||||||
|
|
||||||
if (version < 15) { |
|
||||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == 1) |
|
||||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == -1) |
|
||||||
} else { |
|
||||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaId == -1) |
|
||||||
assertTrue(request.asInstanceOf[FetchRequest].data.replicaState.replicaId == 1) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
private def sendTestRequest( |
|
||||||
apiKey: ApiKeys, |
|
||||||
destinationId: Int, |
|
||||||
): RaftRequest.Outbound = { |
|
||||||
val correlationId = channel.newCorrelationId() |
|
||||||
val createdTimeMs = time.milliseconds() |
|
||||||
val apiRequest = buildTestRequest(apiKey) |
|
||||||
val request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs) |
|
||||||
channel.send(request) |
|
||||||
request |
|
||||||
} |
|
||||||
|
|
||||||
private def assertResponseCompleted( |
|
||||||
request: RaftRequest.Outbound, |
|
||||||
expectedError: Errors |
|
||||||
): Unit = { |
|
||||||
assertTrue(request.completion.isDone) |
|
||||||
|
|
||||||
val response = request.completion.get() |
|
||||||
assertEquals(request.destinationId, response.sourceId) |
|
||||||
assertEquals(request.correlationId, response.correlationId) |
|
||||||
assertEquals(request.data.apiKey, response.data.apiKey) |
|
||||||
assertEquals(expectedError, extractError(response.data)) |
|
||||||
} |
|
||||||
|
|
||||||
private def sendAndAssertErrorResponse( |
|
||||||
apiKey: ApiKeys, |
|
||||||
destinationId: Int, |
|
||||||
error: Errors |
|
||||||
): Unit = { |
|
||||||
val request = sendTestRequest(apiKey, destinationId) |
|
||||||
channel.pollOnce() |
|
||||||
assertResponseCompleted(request, error) |
|
||||||
} |
|
||||||
|
|
||||||
private def buildTestRequest(key: ApiKeys): ApiMessage = { |
|
||||||
val leaderEpoch = 5 |
|
||||||
val leaderId = 1 |
|
||||||
key match { |
|
||||||
case ApiKeys.BEGIN_QUORUM_EPOCH => |
|
||||||
BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId) |
|
||||||
|
|
||||||
case ApiKeys.END_QUORUM_EPOCH => |
|
||||||
EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, |
|
||||||
leaderEpoch, Collections.singletonList(2)) |
|
||||||
|
|
||||||
case ApiKeys.VOTE => |
|
||||||
val lastEpoch = 4 |
|
||||||
VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329) |
|
||||||
|
|
||||||
case ApiKeys.FETCH => |
|
||||||
val request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition => { |
|
||||||
fetchPartition |
|
||||||
.setCurrentLeaderEpoch(5) |
|
||||||
.setFetchOffset(333) |
|
||||||
.setLastFetchedEpoch(5) |
|
||||||
}) |
|
||||||
request.setReplicaState(new ReplicaState().setReplicaId(1)) |
|
||||||
|
|
||||||
case _ => |
|
||||||
throw new AssertionError(s"Unexpected api $key") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
private def buildTestErrorResponse(key: ApiKeys, error: Errors): ApiMessage = { |
|
||||||
key match { |
|
||||||
case ApiKeys.BEGIN_QUORUM_EPOCH => |
|
||||||
new BeginQuorumEpochResponseData() |
|
||||||
.setErrorCode(error.code) |
|
||||||
|
|
||||||
case ApiKeys.END_QUORUM_EPOCH => |
|
||||||
new EndQuorumEpochResponseData() |
|
||||||
.setErrorCode(error.code) |
|
||||||
|
|
||||||
case ApiKeys.VOTE => |
|
||||||
VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); |
|
||||||
|
|
||||||
case ApiKeys.FETCH => |
|
||||||
new FetchResponseData() |
|
||||||
.setErrorCode(error.code) |
|
||||||
|
|
||||||
case _ => |
|
||||||
throw new AssertionError(s"Unexpected api $key") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
private def extractError(response: ApiMessage): Errors = { |
|
||||||
val code = (response: @unchecked) match { |
|
||||||
case res: BeginQuorumEpochResponseData => res.errorCode |
|
||||||
case res: EndQuorumEpochResponseData => res.errorCode |
|
||||||
case res: FetchResponseData => res.errorCode |
|
||||||
case res: VoteResponseData => res.errorCode |
|
||||||
} |
|
||||||
Errors.forCode(code) |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
def buildResponse(responseData: ApiMessage): AbstractResponse = { |
|
||||||
responseData match { |
|
||||||
case voteResponse: VoteResponseData => |
|
||||||
new VoteResponse(voteResponse) |
|
||||||
case beginEpochResponse: BeginQuorumEpochResponseData => |
|
||||||
new BeginQuorumEpochResponse(beginEpochResponse) |
|
||||||
case endEpochResponse: EndQuorumEpochResponseData => |
|
||||||
new EndQuorumEpochResponse(endEpochResponse) |
|
||||||
case fetchResponse: FetchResponseData => |
|
||||||
new FetchResponse(fetchResponse) |
|
||||||
case _ => |
|
||||||
throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
} |
|
||||||
|
|
||||||
object KafkaNetworkChannelTest { |
|
||||||
val RaftApis = Seq( |
|
||||||
ApiKeys.VOTE, |
|
||||||
ApiKeys.BEGIN_QUORUM_EPOCH, |
|
||||||
ApiKeys.END_QUORUM_EPOCH, |
|
||||||
ApiKeys.FETCH, |
|
||||||
) |
|
||||||
|
|
||||||
private class StubMetadataUpdater extends MockMetadataUpdater { |
|
||||||
override def fetchNodes(): util.List[Node] = Collections.emptyList() |
|
||||||
|
|
||||||
override def isUpdateNeeded: Boolean = false |
|
||||||
|
|
||||||
override def update(time: Time, update: MockClient.MetadataUpdate): Unit = {} |
|
||||||
} |
|
||||||
} |
|
@ -0,0 +1,183 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
package org.apache.kafka.raft; |
||||||
|
|
||||||
|
import org.apache.kafka.clients.ClientResponse; |
||||||
|
import org.apache.kafka.clients.KafkaClient; |
||||||
|
import org.apache.kafka.common.Node; |
||||||
|
import org.apache.kafka.common.message.BeginQuorumEpochRequestData; |
||||||
|
import org.apache.kafka.common.message.EndQuorumEpochRequestData; |
||||||
|
import org.apache.kafka.common.message.FetchRequestData; |
||||||
|
import org.apache.kafka.common.message.FetchSnapshotRequestData; |
||||||
|
import org.apache.kafka.common.message.VoteRequestData; |
||||||
|
import org.apache.kafka.common.protocol.ApiKeys; |
||||||
|
import org.apache.kafka.common.protocol.ApiMessage; |
||||||
|
import org.apache.kafka.common.protocol.Errors; |
||||||
|
import org.apache.kafka.common.requests.AbstractRequest; |
||||||
|
import org.apache.kafka.common.requests.BeginQuorumEpochRequest; |
||||||
|
import org.apache.kafka.common.requests.EndQuorumEpochRequest; |
||||||
|
import org.apache.kafka.common.requests.FetchRequest; |
||||||
|
import org.apache.kafka.common.requests.FetchSnapshotRequest; |
||||||
|
import org.apache.kafka.common.requests.VoteRequest; |
||||||
|
import org.apache.kafka.common.utils.Time; |
||||||
|
import org.apache.kafka.server.util.InterBrokerSendThread; |
||||||
|
import org.apache.kafka.server.util.RequestAndCompletionHandler; |
||||||
|
import org.slf4j.Logger; |
||||||
|
import org.slf4j.LoggerFactory; |
||||||
|
|
||||||
|
import java.util.ArrayList; |
||||||
|
import java.util.Collection; |
||||||
|
import java.util.HashMap; |
||||||
|
import java.util.List; |
||||||
|
import java.util.Map; |
||||||
|
import java.util.Queue; |
||||||
|
import java.util.concurrent.ConcurrentLinkedQueue; |
||||||
|
import java.util.concurrent.atomic.AtomicInteger; |
||||||
|
|
||||||
|
public class KafkaNetworkChannel implements NetworkChannel { |
||||||
|
|
||||||
|
static class SendThread extends InterBrokerSendThread { |
||||||
|
|
||||||
|
private Queue<RequestAndCompletionHandler> queue = new ConcurrentLinkedQueue<>(); |
||||||
|
|
||||||
|
public SendThread(String name, KafkaClient networkClient, int requestTimeoutMs, Time time, boolean isInterruptible) { |
||||||
|
super(name, networkClient, requestTimeoutMs, time, isInterruptible); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public Collection<RequestAndCompletionHandler> generateRequests() { |
||||||
|
List<RequestAndCompletionHandler> list = new ArrayList<>(); |
||||||
|
while (true) { |
||||||
|
RequestAndCompletionHandler request = queue.poll(); |
||||||
|
if (request == null) { |
||||||
|
return list; |
||||||
|
} else { |
||||||
|
list.add(request); |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
public void sendRequest(RequestAndCompletionHandler request) { |
||||||
|
queue.add(request); |
||||||
|
wakeup(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(KafkaNetworkChannel.class); |
||||||
|
|
||||||
|
private final SendThread requestThread; |
||||||
|
|
||||||
|
private final AtomicInteger correlationIdCounter = new AtomicInteger(0); |
||||||
|
private final Map<Integer, Node> endpoints = new HashMap<>(); |
||||||
|
|
||||||
|
public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) { |
||||||
|
this.requestThread = new SendThread( |
||||||
|
threadNamePrefix + "-outbound-request-thread", |
||||||
|
client, |
||||||
|
requestTimeoutMs, |
||||||
|
time, |
||||||
|
false |
||||||
|
); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public int newCorrelationId() { |
||||||
|
return correlationIdCounter.getAndIncrement(); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void send(RaftRequest.Outbound request) { |
||||||
|
Node node = endpoints.get(request.destinationId()); |
||||||
|
if (node != null) { |
||||||
|
requestThread.sendRequest(new RequestAndCompletionHandler( |
||||||
|
request.createdTimeMs, |
||||||
|
node, |
||||||
|
buildRequest(request.data), |
||||||
|
response -> sendOnComplete(request, response) |
||||||
|
)); |
||||||
|
} else |
||||||
|
sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)); |
||||||
|
} |
||||||
|
|
||||||
|
private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) { |
||||||
|
RaftResponse.Inbound response = new RaftResponse.Inbound( |
||||||
|
request.correlationId, |
||||||
|
message, |
||||||
|
request.destinationId() |
||||||
|
); |
||||||
|
request.completion.complete(response); |
||||||
|
} |
||||||
|
|
||||||
|
private void sendOnComplete(RaftRequest.Outbound request, ClientResponse clientResponse) { |
||||||
|
ApiMessage response; |
||||||
|
if (clientResponse.versionMismatch() != null) { |
||||||
|
log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch()); |
||||||
|
response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION); |
||||||
|
} else if (clientResponse.authenticationException() != null) { |
||||||
|
// For now we treat authentication errors as retriable. We use the
|
||||||
|
// `NETWORK_EXCEPTION` error code for lack of a good alternative.
|
||||||
|
// Note that `NodeToControllerChannelManager` will still log the
|
||||||
|
// authentication errors so that users have a chance to fix the problem.
|
||||||
|
log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException()); |
||||||
|
response = errorResponse(request.data, Errors.NETWORK_EXCEPTION); |
||||||
|
} else if (clientResponse.wasDisconnected()) { |
||||||
|
response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE); |
||||||
|
} else { |
||||||
|
response = clientResponse.responseBody().data(); |
||||||
|
} |
||||||
|
sendCompleteFuture(request, response); |
||||||
|
} |
||||||
|
|
||||||
|
private ApiMessage errorResponse(ApiMessage request, Errors error) { |
||||||
|
ApiKeys apiKey = ApiKeys.forId(request.apiKey()); |
||||||
|
return RaftUtil.errorResponse(apiKey, error); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void updateEndpoint(int id, RaftConfig.InetAddressSpec spec) { |
||||||
|
Node node = new Node(id, spec.address.getHostString(), spec.address.getPort()); |
||||||
|
endpoints.put(id, node); |
||||||
|
} |
||||||
|
|
||||||
|
public void start() { |
||||||
|
requestThread.start(); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void close() throws InterruptedException { |
||||||
|
requestThread.shutdown(); |
||||||
|
} |
||||||
|
|
||||||
|
// Visible for testing
|
||||||
|
public void pollOnce() { |
||||||
|
requestThread.doWork(); |
||||||
|
} |
||||||
|
|
||||||
|
static AbstractRequest.Builder<? extends AbstractRequest> buildRequest(ApiMessage requestData) { |
||||||
|
if (requestData instanceof VoteRequestData) |
||||||
|
return new VoteRequest.Builder((VoteRequestData) requestData); |
||||||
|
if (requestData instanceof BeginQuorumEpochRequestData) |
||||||
|
return new BeginQuorumEpochRequest.Builder((BeginQuorumEpochRequestData) requestData); |
||||||
|
if (requestData instanceof EndQuorumEpochRequestData) |
||||||
|
return new EndQuorumEpochRequest.Builder((EndQuorumEpochRequestData) requestData); |
||||||
|
if (requestData instanceof FetchRequestData) |
||||||
|
return new FetchRequest.SimpleBuilder((FetchRequestData) requestData); |
||||||
|
if (requestData instanceof FetchSnapshotRequestData) |
||||||
|
return new FetchSnapshotRequest.Builder((FetchSnapshotRequestData) requestData); |
||||||
|
throw new IllegalArgumentException("Unexpected type for requestData: " + requestData); |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,323 @@ |
|||||||
|
/* |
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
||||||
|
* contributor license agreements. See the NOTICE file distributed with |
||||||
|
* this work for additional information regarding copyright ownership. |
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
||||||
|
* (the "License"); you may not use this file except in compliance with |
||||||
|
* the License. You may obtain a copy of the License at |
||||||
|
* |
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
* |
||||||
|
* Unless required by applicable law or agreed to in writing, software |
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
* See the License for the specific language governing permissions and |
||||||
|
* limitations under the License. |
||||||
|
*/ |
||||||
|
package org.apache.kafka.raft; |
||||||
|
|
||||||
|
import org.apache.kafka.clients.MockClient; |
||||||
|
import org.apache.kafka.clients.NodeApiVersions; |
||||||
|
import org.apache.kafka.common.Node; |
||||||
|
import org.apache.kafka.common.TopicPartition; |
||||||
|
import org.apache.kafka.common.Uuid; |
||||||
|
import org.apache.kafka.common.message.ApiVersionsResponseData; |
||||||
|
import org.apache.kafka.common.message.BeginQuorumEpochResponseData; |
||||||
|
import org.apache.kafka.common.message.EndQuorumEpochResponseData; |
||||||
|
import org.apache.kafka.common.message.FetchRequestData; |
||||||
|
import org.apache.kafka.common.message.FetchResponseData; |
||||||
|
import org.apache.kafka.common.message.VoteResponseData; |
||||||
|
import org.apache.kafka.common.protocol.ApiKeys; |
||||||
|
import org.apache.kafka.common.protocol.ApiMessage; |
||||||
|
import org.apache.kafka.common.protocol.Errors; |
||||||
|
import org.apache.kafka.common.requests.AbstractRequest; |
||||||
|
import org.apache.kafka.common.requests.AbstractResponse; |
||||||
|
import org.apache.kafka.common.requests.ApiVersionsResponse; |
||||||
|
import org.apache.kafka.common.requests.BeginQuorumEpochRequest; |
||||||
|
import org.apache.kafka.common.requests.BeginQuorumEpochResponse; |
||||||
|
import org.apache.kafka.common.requests.EndQuorumEpochRequest; |
||||||
|
import org.apache.kafka.common.requests.EndQuorumEpochResponse; |
||||||
|
import org.apache.kafka.common.requests.FetchRequest; |
||||||
|
import org.apache.kafka.common.requests.FetchResponse; |
||||||
|
import org.apache.kafka.common.requests.VoteRequest; |
||||||
|
import org.apache.kafka.common.requests.VoteResponse; |
||||||
|
import org.apache.kafka.common.utils.MockTime; |
||||||
|
import org.apache.kafka.common.utils.Time; |
||||||
|
import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource; |
||||||
|
import org.junit.jupiter.api.BeforeEach; |
||||||
|
import org.junit.jupiter.api.Test; |
||||||
|
import org.junit.jupiter.params.ParameterizedTest; |
||||||
|
|
||||||
|
import java.net.InetSocketAddress; |
||||||
|
import java.util.Collections; |
||||||
|
import java.util.List; |
||||||
|
import java.util.concurrent.ExecutionException; |
||||||
|
import java.util.stream.Collectors; |
||||||
|
|
||||||
|
import static java.util.Arrays.asList; |
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals; |
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue; |
||||||
|
|
||||||
|
public class KafkaNetworkChannelTest { |
||||||
|
|
||||||
|
private static class StubMetadataUpdater implements MockClient.MockMetadataUpdater { |
||||||
|
|
||||||
|
@Override |
||||||
|
public List<Node> fetchNodes() { |
||||||
|
return Collections.emptyList(); |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public boolean isUpdateNeeded() { |
||||||
|
return false; |
||||||
|
} |
||||||
|
|
||||||
|
@Override |
||||||
|
public void update(Time time, MockClient.MetadataUpdate update) { } |
||||||
|
} |
||||||
|
|
||||||
|
private static final List<ApiKeys> RAFT_APIS = asList( |
||||||
|
ApiKeys.VOTE, |
||||||
|
ApiKeys.BEGIN_QUORUM_EPOCH, |
||||||
|
ApiKeys.END_QUORUM_EPOCH, |
||||||
|
ApiKeys.FETCH |
||||||
|
); |
||||||
|
|
||||||
|
private final String clusterId = "clusterId"; |
||||||
|
private final int requestTimeoutMs = 30000; |
||||||
|
private final Time time = new MockTime(); |
||||||
|
private final MockClient client = new MockClient(time, new StubMetadataUpdater()); |
||||||
|
private final TopicPartition topicPartition = new TopicPartition("topic", 0); |
||||||
|
private final Uuid topicId = Uuid.randomUuid(); |
||||||
|
private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft"); |
||||||
|
|
||||||
|
@BeforeEach |
||||||
|
public void setupSupportedApis() { |
||||||
|
List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS.stream().map( |
||||||
|
ApiVersionsResponse::toApiVersion).collect(Collectors.toList()); |
||||||
|
client.setNodeApiVersions(NodeApiVersions.create(supportedApis)); |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testSendToUnknownDestination() throws ExecutionException, InterruptedException { |
||||||
|
int destinationId = 2; |
||||||
|
assertBrokerNotAvailable(destinationId); |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException { |
||||||
|
int destinationId = 2; |
||||||
|
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||||
|
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||||
|
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||||
|
client.backoff(destinationNode, 500); |
||||||
|
assertBrokerNotAvailable(destinationId); |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testWakeupClientOnSend() throws InterruptedException, ExecutionException { |
||||||
|
int destinationId = 2; |
||||||
|
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||||
|
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||||
|
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||||
|
|
||||||
|
client.enableBlockingUntilWakeup(1); |
||||||
|
|
||||||
|
Thread ioThread = new Thread(() -> { |
||||||
|
// Block in poll until we get the expected wakeup
|
||||||
|
channel.pollOnce(); |
||||||
|
|
||||||
|
// Poll a second time to send request and receive response
|
||||||
|
channel.pollOnce(); |
||||||
|
}); |
||||||
|
|
||||||
|
AbstractResponse response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST)); |
||||||
|
client.prepareResponseFrom(response, destinationNode, false); |
||||||
|
|
||||||
|
ioThread.start(); |
||||||
|
RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId); |
||||||
|
|
||||||
|
ioThread.join(); |
||||||
|
assertResponseCompleted(request, Errors.INVALID_REQUEST); |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testSendAndDisconnect() throws ExecutionException, InterruptedException { |
||||||
|
int destinationId = 2; |
||||||
|
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||||
|
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||||
|
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||||
|
|
||||||
|
for (ApiKeys apiKey : RAFT_APIS) { |
||||||
|
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)); |
||||||
|
client.prepareResponseFrom(response, destinationNode, true); |
||||||
|
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException { |
||||||
|
int destinationId = 2; |
||||||
|
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||||
|
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||||
|
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||||
|
|
||||||
|
for (ApiKeys apiKey : RAFT_APIS) { |
||||||
|
client.createPendingAuthenticationError(destinationNode, 100); |
||||||
|
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION); |
||||||
|
|
||||||
|
// reset to clear backoff time
|
||||||
|
client.reset(); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException { |
||||||
|
for (ApiKeys apiKey : RAFT_APIS) { |
||||||
|
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException { |
||||||
|
int destinationId = 2; |
||||||
|
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||||
|
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||||
|
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||||
|
|
||||||
|
for (ApiKeys apiKey : RAFT_APIS) { |
||||||
|
Errors expectedError = Errors.INVALID_REQUEST; |
||||||
|
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError)); |
||||||
|
client.prepareResponseFrom(response, destinationNode); |
||||||
|
System.out.println("api key " + apiKey + ", response " + response); |
||||||
|
sendAndAssertErrorResponse(apiKey, destinationId, expectedError); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@Test |
||||||
|
public void testUnsupportedVersionError() throws ExecutionException, InterruptedException { |
||||||
|
int destinationId = 2; |
||||||
|
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||||
|
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||||
|
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||||
|
|
||||||
|
for (ApiKeys apiKey : RAFT_APIS) { |
||||||
|
client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey); |
||||||
|
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
@ParameterizedTest |
||||||
|
@ApiKeyVersionsSource(apiKey = ApiKeys.FETCH) |
||||||
|
public void testFetchRequestDowngrade(short version) { |
||||||
|
int destinationId = 2; |
||||||
|
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); |
||||||
|
channel.updateEndpoint(destinationId, new RaftConfig.InetAddressSpec( |
||||||
|
new InetSocketAddress(destinationNode.host(), destinationNode.port()))); |
||||||
|
sendTestRequest(ApiKeys.FETCH, destinationId); |
||||||
|
channel.pollOnce(); |
||||||
|
|
||||||
|
assertEquals(1, client.requests().size()); |
||||||
|
AbstractRequest request = client.requests().peek().requestBuilder().build(version); |
||||||
|
|
||||||
|
if (version < 15) { |
||||||
|
assertTrue(((FetchRequest) request).data().replicaId() == 1); |
||||||
|
assertTrue(((FetchRequest) request).data().replicaState().replicaId() == -1); |
||||||
|
} else { |
||||||
|
assertTrue(((FetchRequest) request).data().replicaId() == -1); |
||||||
|
assertTrue(((FetchRequest) request).data().replicaState().replicaId() == 1); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) { |
||||||
|
int correlationId = channel.newCorrelationId(); |
||||||
|
long createdTimeMs = time.milliseconds(); |
||||||
|
ApiMessage apiRequest = buildTestRequest(apiKey); |
||||||
|
RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs); |
||||||
|
channel.send(request); |
||||||
|
return request; |
||||||
|
} |
||||||
|
|
||||||
|
private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException { |
||||||
|
assertTrue(request.completion.isDone()); |
||||||
|
|
||||||
|
RaftResponse.Inbound response = request.completion.get(); |
||||||
|
assertEquals(request.destinationId(), response.sourceId()); |
||||||
|
assertEquals(request.correlationId, response.correlationId); |
||||||
|
assertEquals(request.data.apiKey(), response.data.apiKey()); |
||||||
|
assertEquals(expectedError, extractError(response.data)); |
||||||
|
} |
||||||
|
|
||||||
|
private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException { |
||||||
|
RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId); |
||||||
|
channel.pollOnce(); |
||||||
|
assertResponseCompleted(request, error); |
||||||
|
} |
||||||
|
|
||||||
|
private ApiMessage buildTestRequest(ApiKeys key) { |
||||||
|
int leaderEpoch = 5; |
||||||
|
int leaderId = 1; |
||||||
|
switch (key) { |
||||||
|
case BEGIN_QUORUM_EPOCH: |
||||||
|
return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId); |
||||||
|
case END_QUORUM_EPOCH: |
||||||
|
return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch, |
||||||
|
Collections.singletonList(2)); |
||||||
|
case VOTE: |
||||||
|
int lastEpoch = 4; |
||||||
|
return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329); |
||||||
|
case FETCH: |
||||||
|
FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> { |
||||||
|
fetchPartition |
||||||
|
.setCurrentLeaderEpoch(5) |
||||||
|
.setFetchOffset(333) |
||||||
|
.setLastFetchedEpoch(5); |
||||||
|
}); |
||||||
|
request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1)); |
||||||
|
return request; |
||||||
|
default: |
||||||
|
throw new AssertionError("Unexpected api " + key); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private ApiMessage buildTestErrorResponse(ApiKeys key, Errors error) { |
||||||
|
switch (key) { |
||||||
|
case BEGIN_QUORUM_EPOCH: |
||||||
|
return new BeginQuorumEpochResponseData().setErrorCode(error.code()); |
||||||
|
case END_QUORUM_EPOCH: |
||||||
|
return new EndQuorumEpochResponseData().setErrorCode(error.code()); |
||||||
|
case VOTE: |
||||||
|
return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); |
||||||
|
case FETCH: |
||||||
|
return new FetchResponseData().setErrorCode(error.code()); |
||||||
|
default: |
||||||
|
throw new AssertionError("Unexpected api " + key); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
private Errors extractError(ApiMessage response) { |
||||||
|
short code; |
||||||
|
if (response instanceof BeginQuorumEpochResponseData) |
||||||
|
code = ((BeginQuorumEpochResponseData) response).errorCode(); |
||||||
|
else if (response instanceof EndQuorumEpochResponseData) |
||||||
|
code = ((EndQuorumEpochResponseData) response).errorCode(); |
||||||
|
else if (response instanceof FetchResponseData) |
||||||
|
code = ((FetchResponseData) response).errorCode(); |
||||||
|
else if (response instanceof VoteResponseData) |
||||||
|
code = ((VoteResponseData) response).errorCode(); |
||||||
|
else |
||||||
|
throw new IllegalArgumentException("Unexpected type for responseData: " + response); |
||||||
|
return Errors.forCode(code); |
||||||
|
} |
||||||
|
|
||||||
|
private AbstractResponse buildResponse(ApiMessage responseData) { |
||||||
|
if (responseData instanceof VoteResponseData) |
||||||
|
return new VoteResponse((VoteResponseData) responseData); |
||||||
|
if (responseData instanceof BeginQuorumEpochResponseData) |
||||||
|
return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData); |
||||||
|
if (responseData instanceof EndQuorumEpochResponseData) |
||||||
|
return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData); |
||||||
|
if (responseData instanceof FetchResponseData) |
||||||
|
return new FetchResponse((FetchResponseData) responseData); |
||||||
|
throw new IllegalArgumentException("Unexpected type for responseData: " + responseData); |
||||||
|
} |
||||||
|
} |
Loading…
Reference in new issue