diff --git a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java index 4d01cdeb2e2..d22b508cd87 100644 --- a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java +++ b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java @@ -390,7 +390,7 @@ public class NetworkClient implements KafkaClient { } /** - * Iterate over all the inflight requests and expire any requests that have exceeded the configured the requestTimeout. + * Iterate over all the inflight requests and expire any requests that have exceeded the configured requestTimeout. * The connection to the node associated with the request will be terminated and will be treated as a disconnection. * * @param responses The list of responses to update diff --git a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala index b376d15e4eb..e9731fd4e09 100755 --- a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala +++ b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala @@ -178,9 +178,7 @@ class RequestSendThread(val controllerId: Int, val requestHeader = apiVersion.fold(networkClient.nextRequestHeader(apiKey))(networkClient.nextRequestHeader(apiKey, _)) val send = new RequestSend(brokerNode.idString, requestHeader, request.toStruct) val clientRequest = new ClientRequest(time.milliseconds(), true, send, null) - clientResponse = networkClient.blockingSendAndReceive(clientRequest, socketTimeoutMs)(time).getOrElse { - throw new SocketTimeoutException(s"No response received within $socketTimeoutMs ms") - } + clientResponse = networkClient.blockingSendAndReceive(clientRequest)(time) isSendSuccessful = true } } catch { diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index e29494baa1d..f998d82104d 100755 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -320,9 +320,6 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime, threadNamePr val socketTimeoutMs = config.controllerSocketTimeoutMs - def socketTimeoutException: Throwable = - new SocketTimeoutException(s"Did not receive response within $socketTimeoutMs") - def networkClientControlledShutdown(retries: Int): Boolean = { val metadataUpdater = new ManualMetadataUpdater() val networkClient = { @@ -388,16 +385,14 @@ class KafkaServer(val config: KafkaConfig, time: Time = SystemTime, threadNamePr try { if (!networkClient.blockingReady(node(prevController), socketTimeoutMs)) - throw socketTimeoutException + throw new SocketTimeoutException(s"Failed to connect within $socketTimeoutMs ms") // send the controlled shutdown request val requestHeader = networkClient.nextRequestHeader(ApiKeys.CONTROLLED_SHUTDOWN_KEY) val send = new RequestSend(node(prevController).idString, requestHeader, new ControlledShutdownRequest(config.brokerId).toStruct) val request = new ClientRequest(kafkaMetricsTime.milliseconds(), true, send, null) - val clientResponse = networkClient.blockingSendAndReceive(request, socketTimeoutMs).getOrElse { - throw socketTimeoutException - } + val clientResponse = networkClient.blockingSendAndReceive(request) val shutdownResponse = new ControlledShutdownResponse(clientResponse.responseBody) if (shutdownResponse.errorCode == Errors.NONE.code && shutdownResponse.partitionsRemaining.isEmpty) { diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala index de7269f8332..26838cac96d 100644 --- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala +++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala @@ -233,9 +233,7 @@ class ReplicaFetcherThread(name: String, else { val send = new RequestSend(sourceBroker.id.toString, header, request.toStruct) val clientRequest = new ClientRequest(time.milliseconds(), true, send, null) - networkClient.blockingSendAndReceive(clientRequest, socketTimeout)(time).getOrElse { - throw new SocketTimeoutException(s"No response received within $socketTimeout ms") - } + networkClient.blockingSendAndReceive(clientRequest)(time) } } catch { diff --git a/core/src/main/scala/kafka/utils/NetworkClientBlockingOps.scala b/core/src/main/scala/kafka/utils/NetworkClientBlockingOps.scala index 9ed9d29a293..fd4af6e949b 100644 --- a/core/src/main/scala/kafka/utils/NetworkClientBlockingOps.scala +++ b/core/src/main/scala/kafka/utils/NetworkClientBlockingOps.scala @@ -55,6 +55,7 @@ class NetworkClientBlockingOps(val client: NetworkClient) extends AnyVal { * care. */ def blockingReady(node: Node, timeout: Long)(implicit time: JTime): Boolean = { + require(timeout >=0, "timeout should be >= 0") client.ready(node, time.milliseconds()) || pollUntil(timeout) { (_, now) => if (client.isReady(node, now)) true @@ -65,19 +66,18 @@ class NetworkClientBlockingOps(val client: NetworkClient) extends AnyVal { } /** - * Invokes `client.send` followed by 1 or more `client.poll` invocations until a response is received, - * the timeout expires or a disconnection happens. + * Invokes `client.send` followed by 1 or more `client.poll` invocations until a response is received or a + * disconnection happens (which can happen for a number of reasons including a request timeout). * - * It returns `true` if the call completes normally or `false` if the timeout expires. In the case of a disconnection, - * an `IOException` is thrown instead. + * In case of a disconnection, an `IOException` is thrown. * * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with * care. */ - def blockingSendAndReceive(request: ClientRequest, timeout: Long)(implicit time: JTime): Option[ClientResponse] = { + def blockingSendAndReceive(request: ClientRequest)(implicit time: JTime): ClientResponse = { client.send(request, time.milliseconds()) - pollUntilFound(timeout) { case (responses, _) => + pollContinuously { responses => val response = responses.find { response => response.request.request.header.correlationId == request.request.header.correlationId } @@ -102,41 +102,45 @@ class NetworkClientBlockingOps(val client: NetworkClient) extends AnyVal { * care. */ private def pollUntil(timeout: Long)(predicate: (Seq[ClientResponse], Long) => Boolean)(implicit time: JTime): Boolean = { - pollUntilFound(timeout) { (responses, now) => - if (predicate(responses, now)) Some(true) - else None - }.fold(false)(_ => true) - } - - /** - * Invokes `client.poll` until `collect` returns `Some` or the timeout expires. - * - * It returns the result of `collect` if the call completes normally or `None` if the timeout expires. Exceptions - * thrown via `collect` are not handled and will bubble up. - * - * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with - * care. - */ - private def pollUntilFound[T](timeout: Long)(collect: (Seq[ClientResponse], Long) => Option[T])(implicit time: JTime): Option[T] = { - val methodStartTime = time.milliseconds() val timeoutExpiryTime = methodStartTime + timeout @tailrec - def recurse(iterationStartTime: Long): Option[T] = { - val pollTimeout = if (timeout < 0) timeout else timeoutExpiryTime - iterationStartTime + def recursivePoll(iterationStartTime: Long): Boolean = { + val pollTimeout = timeoutExpiryTime - iterationStartTime val responses = client.poll(pollTimeout, iterationStartTime).asScala - val result = collect(responses, iterationStartTime) - if (result.isDefined) result + if (predicate(responses, iterationStartTime)) true else { val afterPollTime = time.milliseconds() - if (timeout < 0 || afterPollTime < timeoutExpiryTime) - recurse(afterPollTime) - else None + if (afterPollTime < timeoutExpiryTime) recursivePoll(afterPollTime) + else false + } + } + + recursivePoll(methodStartTime) + } + + /** + * Invokes `client.poll` until `collect` returns `Some`. The value inside `Some` is returned. + * + * Exceptions thrown via `collect` are not handled and will bubble up. + * + * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with + * care. + */ + private def pollContinuously[T](collect: Seq[ClientResponse] => Option[T])(implicit time: JTime): T = { + + @tailrec + def recursivePoll: T = { + // rely on request timeout to ensure we don't block forever + val responses = client.poll(Long.MaxValue, time.milliseconds()).asScala + collect(responses) match { + case Some(result) => result + case None => recursivePoll } } - recurse(methodStartTime) + recursivePoll } }