diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index 6bfcfd21a90..ed037b3a8f7 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -325,9 +325,9 @@ public class Selector implements Selectable, AutoCloseable { } catch (Exception e) { // update the state for consistency, the channel will be discarded after `close` channel.state(ChannelState.FAILED_SEND); - // ensure notification via `disconnected` + // ensure notification via `disconnected` when `failedSends` are processed in the next poll this.failedSends.add(connectionId); - close(channel, false); + close(channel, false, false); if (!(e instanceof CancelledKeyException)) { log.error("Unexpected exception during send, closing connection {} and rethrowing exception {}", connectionId, e); @@ -450,6 +450,7 @@ public class Selector implements Selectable, AutoCloseable { if (idleExpiryManager != null) idleExpiryManager.update(channel.id(), currentTimeNanos); + boolean sendFailed = false; try { /* complete any connections that have finished their handshake (either normally or immediately) */ @@ -491,7 +492,13 @@ public class Selector implements Selectable, AutoCloseable { /* if channel is ready write to any sockets that have space in their buffer and for which we have data */ if (channel.ready() && key.isWritable()) { - Send send = channel.write(); + Send send = null; + try { + send = channel.write(); + } catch (Exception e) { + sendFailed = true; + throw e; + } if (send != null) { this.completedSends.add(send); this.sensors.recordBytesSent(channel.id(), send.size()); @@ -500,7 +507,7 @@ public class Selector implements Selectable, AutoCloseable { /* cancel any defunct sockets */ if (!key.isValid()) - close(channel, true); + close(channel, true, true); } catch (Exception e) { String desc = channel.socketDescription(); @@ -510,7 +517,7 @@ public class Selector implements Selectable, AutoCloseable { log.debug("Connection with {} disconnected due to authentication exception", desc, e); else log.warn("Unexpected error from {}; closing connection", desc, e); - close(channel, true); + close(channel, !sendFailed, true); } finally { maybeRecordTimePerConnection(channel, channelStartTimeNanos); } @@ -620,7 +627,7 @@ public class Selector implements Selectable, AutoCloseable { log.trace("About to close the idle connection from {} due to being idle for {} millis", connectionId, (currentTimeNanos - expiredConnection.getValue()) / 1000 / 1000); channel.state(ChannelState.EXPIRED); - close(channel, true); + close(channel, true, true); } } } @@ -674,7 +681,7 @@ public class Selector implements Selectable, AutoCloseable { // There is no disconnect notification for local close, but updating // channel state here anyway to avoid confusion. channel.state(ChannelState.LOCAL_CLOSE); - close(channel, false); + close(channel, false, false); } else { KafkaChannel closingChannel = this.closingChannels.remove(id); // Close any closing channel, leave the channel in the state in which closing was triggered @@ -694,7 +701,10 @@ public class Selector implements Selectable, AutoCloseable { * closed immediately. The channel will not be added to disconnected list and it is the * responsibility of the caller to handle disconnect notifications. */ - private void close(KafkaChannel channel, boolean processOutstanding) { + private void close(KafkaChannel channel, boolean processOutstanding, boolean notifyDisconnect) { + + if (processOutstanding && !notifyDisconnect) + throw new IllegalStateException("Disconnect notification required for remote disconnect after processing outstanding requests"); channel.disconnect(); @@ -712,8 +722,9 @@ public class Selector implements Selectable, AutoCloseable { if (processOutstanding && deque != null && !deque.isEmpty()) { // stagedReceives will be moved to completedReceives later along with receives from other channels closingChannels.put(channel.id(), channel); + log.debug("Tracking closing connection {} to process outstanding requests", channel.id()); } else - doClose(channel, processOutstanding); + doClose(channel, notifyDisconnect); this.channels.remove(channel.id()); if (idleExpiryManager != null) diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 724e05b479b..a057e54bd7e 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -149,7 +149,7 @@ class SocketServerTest extends JUnitSuite { } def sendAndReceiveRequest(socket: Socket, server: SocketServer): RequestChannel.Request = { - sendRequest(socket, producerRequestBytes) + sendRequest(socket, producerRequestBytes()) receiveRequest(server.requestChannel) } @@ -158,11 +158,10 @@ class SocketServerTest extends JUnitSuite { server.metrics.close() } - private def producerRequestBytes: Array[Byte] = { + private def producerRequestBytes(ack: Short = 0): Array[Byte] = { val correlationId = -1 val clientId = "" val ackTimeoutMs = 10000 - val ack = 0: Short val emptyRequest = ProduceRequest.Builder.forCurrentMagic(ack, ackTimeoutMs, new HashMap[TopicPartition, MemoryRecords]()).build() @@ -178,7 +177,7 @@ class SocketServerTest extends JUnitSuite { @Test def simpleRequest() { val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() // Test PLAINTEXT socket sendRequest(plainSocket, serializedBytes) @@ -207,7 +206,7 @@ class SocketServerTest extends JUnitSuite { @Test def testGracefulClose() { val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() for (_ <- 0 until 10) sendRequest(plainSocket, serializedBytes) @@ -222,7 +221,7 @@ class SocketServerTest extends JUnitSuite { @Test def testNoOpAction(): Unit = { val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() for (_ <- 0 until 3) sendRequest(plainSocket, serializedBytes) @@ -236,7 +235,7 @@ class SocketServerTest extends JUnitSuite { @Test def testConnectionId() { val sockets = (1 to 5).map(_ => connect(protocol = SecurityProtocol.PLAINTEXT)) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() val requests = sockets.map{socket => sendRequest(socket, serializedBytes) @@ -265,7 +264,7 @@ class SocketServerTest extends JUnitSuite { try { overrideServer.startup() - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() // Connection with no staged receives val socket1 = connect(overrideServer, protocol = SecurityProtocol.PLAINTEXT) @@ -348,7 +347,7 @@ class SocketServerTest extends JUnitSuite { // Send requests to `channel1` until a receive is staged and advance time beyond idle time so that `channel1` is // closed with staged receives and is in Selector.closingChannels - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() val request = sendRequestsUntilStagedReceive(overrideServer, socket1, serializedBytes) time.sleep(idleTimeMs + 1) TestUtils.waitUntilTrue(() => openChannel.isEmpty, "Idle channel not closed") @@ -438,7 +437,7 @@ class SocketServerTest extends JUnitSuite { TestUtils.waitUntilTrue(() => server.connectionCount(address) < conns.length, "Failed to decrement connection count after close") val conn2 = connect() - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conn2, serializedBytes) val request = server.requestChannel.receiveRequest(2000) assertNotNull(request) @@ -457,7 +456,7 @@ class SocketServerTest extends JUnitSuite { val conns = (0 until overrideNum).map(_ => connect(overrideServer)) // it should succeed - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conns.last, serializedBytes) val request = overrideServer.requestChannel.receiveRequest(2000) assertNotNull(request) @@ -540,7 +539,7 @@ class SocketServerTest extends JUnitSuite { try { overrideServer.startup() conn = connect(overrideServer) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conn, serializedBytes) val channel = overrideServer.requestChannel @@ -565,6 +564,54 @@ class SocketServerTest extends JUnitSuite { } } + @Test + def testClientDisconnectionWithStagedReceivesFullyProcessed() { + val serverMetrics = new Metrics + @volatile var selector: TestableSelector = null + val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0" + val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, Time.SYSTEM, credentialProvider) { + override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, + protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { + new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, + config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool, new LogContext()) { + override protected[network] def connectionId(socket: Socket): String = overrideConnectionId + override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { + val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) + selector = testableSelector + testableSelector + } + } + } + } + + def openChannel: Option[KafkaChannel] = overrideServer.processor(0).channel(overrideConnectionId) + def openOrClosingChannel: Option[KafkaChannel] = overrideServer.processor(0).openOrClosingChannel(overrideConnectionId) + + try { + overrideServer.startup() + val socket = connect(overrideServer) + + TestUtils.waitUntilTrue(() => openChannel.nonEmpty, "Channel not found") + + // Setup channel to client with staged receives so when client disconnects + // it will be stored in Selector.closingChannels + val serializedBytes = producerRequestBytes(1) + val request = sendRequestsUntilStagedReceive(overrideServer, socket, serializedBytes) + + // Set SoLinger to 0 to force a hard disconnect via TCP RST + socket.setSoLinger(true, 0) + socket.close() + + // Complete request with socket exception so that the channel is removed from Selector.closingChannels + processRequest(overrideServer.requestChannel, request) + TestUtils.waitUntilTrue(() => openOrClosingChannel.isEmpty, "Channel not closed after failed send") + assertTrue("Unexpected completed send", selector.completedSends.isEmpty) + } finally { + overrideServer.shutdown() + serverMetrics.close() + } + } + /* * Test that we update request metrics if the channel has been removed from the selector when the broker calls * `selector.send` (selector closes old connections, for example). @@ -579,7 +626,7 @@ class SocketServerTest extends JUnitSuite { try { overrideServer.startup() conn = connect(overrideServer) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conn, serializedBytes) val channel = overrideServer.requestChannel val request = receiveRequest(channel) @@ -698,7 +745,7 @@ class SocketServerTest extends JUnitSuite { testableSelector.updateMinWakeup(2) val sockets = (1 to 2).map(_ => connect(testableServer)) - sockets.foreach(sendRequest(_, producerRequestBytes)) + sockets.foreach(sendRequest(_, producerRequestBytes())) testableServer.testableSelector.addFailure(SelectorOperation.Send) sockets.foreach(_ => processRequest(testableServer.requestChannel)) @@ -721,7 +768,7 @@ class SocketServerTest extends JUnitSuite { testableSelector.updateMinWakeup(2) val sockets = (1 to 2).map(_ => connect(testableServer)) - sockets.foreach(sendRequest(_, producerRequestBytes)) + sockets.foreach(sendRequest(_, producerRequestBytes())) val requestChannel = testableServer.requestChannel val requests = sockets.map(_ => receiveRequest(requestChannel)) @@ -749,7 +796,7 @@ class SocketServerTest extends JUnitSuite { testableSelector.updateMinWakeup(2) val sockets = (1 to 2).map(_ => connect(testableServer)) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() val request = sendRequestsUntilStagedReceive(testableServer, sockets(0), serializedBytes) sendRequest(sockets(1), serializedBytes) @@ -783,7 +830,7 @@ class SocketServerTest extends JUnitSuite { testableSelector.cachedCompletedReceives.minPerPoll = 2 testableSelector.addFailure(SelectorOperation.Mute) - sockets.foreach(sendRequest(_, producerRequestBytes)) + sockets.foreach(sendRequest(_, producerRequestBytes())) val requests = sockets.map(_ => receiveRequest(requestChannel)) testableSelector.waitForOperations(SelectorOperation.Mute, 2) testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true) @@ -907,7 +954,7 @@ class SocketServerTest extends JUnitSuite { // Check new channel behaves as expected val (socket, connectionId) = connectAndProcessRequest(testableServer) - assertArrayEquals(producerRequestBytes, receiveResponse(socket)) + assertArrayEquals(producerRequestBytes(), receiveResponse(socket)) assertNotNull("Channel should not have been closed", selector.channel(connectionId)) assertNull("Channel should not be closing", selector.closingChannel(connectionId)) socket.close()