Browse Source

KAFKA-6529: Stop file descriptor leak when client disconnects with staged receives (#4517)

If an exception is encountered while sending data to a client connection, that connection is disconnected. If there are staged receives for that connection, the connection is tracked as closing so that those outstanding receives can still be processed. However, if the exception was encountered while processing a `RequestChannel.Request`, the `KafkaChannel` for that connection is muted, so it won't be processed.

Disable processing of outstanding staged receives if a send fails. This stops the leak of both the memory held for pending requests and the file descriptor of the TCP socket.

Test that a channel is closed when an exception is raised while writing to a socket that has been closed by the client. Since sending a response requires acks != 0, allow specifying the required acks for test requests in SocketServerTest.scala.

Author: Graham Campbell <graham.campbell@salesforce.com>

Reviewers: Jason Gustafson <jason@confluent.io>, Rajini Sivaram <rajinisivaram@googlemail.com>, Ismael Juma <ismael@juma.me.uk>, Ted Yu <yuzhihong@gmail.com>
pull/3876/merge
parafiend 7 years ago committed by Rajini Sivaram
parent
commit
fc56a90e05
  1. 29
      clients/src/main/java/org/apache/kafka/common/network/Selector.java
  2. 83
      core/src/test/scala/unit/kafka/network/SocketServerTest.scala

29
clients/src/main/java/org/apache/kafka/common/network/Selector.java

@ -325,9 +325,9 @@ public class Selector implements Selectable, AutoCloseable { @@ -325,9 +325,9 @@ public class Selector implements Selectable, AutoCloseable {
} catch (Exception e) {
// update the state for consistency, the channel will be discarded after `close`
channel.state(ChannelState.FAILED_SEND);
// ensure notification via `disconnected`
// ensure notification via `disconnected` when `failedSends` are processed in the next poll
this.failedSends.add(connectionId);
close(channel, false);
close(channel, false, false);
if (!(e instanceof CancelledKeyException)) {
log.error("Unexpected exception during send, closing connection {} and rethrowing exception {}",
connectionId, e);
@ -450,6 +450,7 @@ public class Selector implements Selectable, AutoCloseable { @@ -450,6 +450,7 @@ public class Selector implements Selectable, AutoCloseable {
if (idleExpiryManager != null)
idleExpiryManager.update(channel.id(), currentTimeNanos);
boolean sendFailed = false;
try {
/* complete any connections that have finished their handshake (either normally or immediately) */
@ -491,7 +492,13 @@ public class Selector implements Selectable, AutoCloseable { @@ -491,7 +492,13 @@ public class Selector implements Selectable, AutoCloseable {
/* if channel is ready write to any sockets that have space in their buffer and for which we have data */
if (channel.ready() && key.isWritable()) {
Send send = channel.write();
Send send = null;
try {
send = channel.write();
} catch (Exception e) {
sendFailed = true;
throw e;
}
if (send != null) {
this.completedSends.add(send);
this.sensors.recordBytesSent(channel.id(), send.size());
@ -500,7 +507,7 @@ public class Selector implements Selectable, AutoCloseable { @@ -500,7 +507,7 @@ public class Selector implements Selectable, AutoCloseable {
/* cancel any defunct sockets */
if (!key.isValid())
close(channel, true);
close(channel, true, true);
} catch (Exception e) {
String desc = channel.socketDescription();
@ -510,7 +517,7 @@ public class Selector implements Selectable, AutoCloseable { @@ -510,7 +517,7 @@ public class Selector implements Selectable, AutoCloseable {
log.debug("Connection with {} disconnected due to authentication exception", desc, e);
else
log.warn("Unexpected error from {}; closing connection", desc, e);
close(channel, true);
close(channel, !sendFailed, true);
} finally {
maybeRecordTimePerConnection(channel, channelStartTimeNanos);
}
@ -620,7 +627,7 @@ public class Selector implements Selectable, AutoCloseable { @@ -620,7 +627,7 @@ public class Selector implements Selectable, AutoCloseable {
log.trace("About to close the idle connection from {} due to being idle for {} millis",
connectionId, (currentTimeNanos - expiredConnection.getValue()) / 1000 / 1000);
channel.state(ChannelState.EXPIRED);
close(channel, true);
close(channel, true, true);
}
}
}
@ -674,7 +681,7 @@ public class Selector implements Selectable, AutoCloseable { @@ -674,7 +681,7 @@ public class Selector implements Selectable, AutoCloseable {
// There is no disconnect notification for local close, but updating
// channel state here anyway to avoid confusion.
channel.state(ChannelState.LOCAL_CLOSE);
close(channel, false);
close(channel, false, false);
} else {
KafkaChannel closingChannel = this.closingChannels.remove(id);
// Close any closing channel, leave the channel in the state in which closing was triggered
@ -694,7 +701,10 @@ public class Selector implements Selectable, AutoCloseable { @@ -694,7 +701,10 @@ public class Selector implements Selectable, AutoCloseable {
* closed immediately. The channel will not be added to disconnected list and it is the
* responsibility of the caller to handle disconnect notifications.
*/
private void close(KafkaChannel channel, boolean processOutstanding) {
private void close(KafkaChannel channel, boolean processOutstanding, boolean notifyDisconnect) {
if (processOutstanding && !notifyDisconnect)
throw new IllegalStateException("Disconnect notification required for remote disconnect after processing outstanding requests");
channel.disconnect();
@ -712,8 +722,9 @@ public class Selector implements Selectable, AutoCloseable { @@ -712,8 +722,9 @@ public class Selector implements Selectable, AutoCloseable {
if (processOutstanding && deque != null && !deque.isEmpty()) {
// stagedReceives will be moved to completedReceives later along with receives from other channels
closingChannels.put(channel.id(), channel);
log.debug("Tracking closing connection {} to process outstanding requests", channel.id());
} else
doClose(channel, processOutstanding);
doClose(channel, notifyDisconnect);
this.channels.remove(channel.id());
if (idleExpiryManager != null)

83
core/src/test/scala/unit/kafka/network/SocketServerTest.scala

@ -149,7 +149,7 @@ class SocketServerTest extends JUnitSuite { @@ -149,7 +149,7 @@ class SocketServerTest extends JUnitSuite {
}
def sendAndReceiveRequest(socket: Socket, server: SocketServer): RequestChannel.Request = {
sendRequest(socket, producerRequestBytes)
sendRequest(socket, producerRequestBytes())
receiveRequest(server.requestChannel)
}
@ -158,11 +158,10 @@ class SocketServerTest extends JUnitSuite { @@ -158,11 +158,10 @@ class SocketServerTest extends JUnitSuite {
server.metrics.close()
}
private def producerRequestBytes: Array[Byte] = {
private def producerRequestBytes(ack: Short = 0): Array[Byte] = {
val correlationId = -1
val clientId = ""
val ackTimeoutMs = 10000
val ack = 0: Short
val emptyRequest = ProduceRequest.Builder.forCurrentMagic(ack, ackTimeoutMs,
new HashMap[TopicPartition, MemoryRecords]()).build()
@ -178,7 +177,7 @@ class SocketServerTest extends JUnitSuite { @@ -178,7 +177,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def simpleRequest() {
val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT)
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
// Test PLAINTEXT socket
sendRequest(plainSocket, serializedBytes)
@ -207,7 +206,7 @@ class SocketServerTest extends JUnitSuite { @@ -207,7 +206,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def testGracefulClose() {
val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT)
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
for (_ <- 0 until 10)
sendRequest(plainSocket, serializedBytes)
@ -222,7 +221,7 @@ class SocketServerTest extends JUnitSuite { @@ -222,7 +221,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def testNoOpAction(): Unit = {
val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT)
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
for (_ <- 0 until 3)
sendRequest(plainSocket, serializedBytes)
@ -236,7 +235,7 @@ class SocketServerTest extends JUnitSuite { @@ -236,7 +235,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def testConnectionId() {
val sockets = (1 to 5).map(_ => connect(protocol = SecurityProtocol.PLAINTEXT))
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
val requests = sockets.map{socket =>
sendRequest(socket, serializedBytes)
@ -265,7 +264,7 @@ class SocketServerTest extends JUnitSuite { @@ -265,7 +264,7 @@ class SocketServerTest extends JUnitSuite {
try {
overrideServer.startup()
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
// Connection with no staged receives
val socket1 = connect(overrideServer, protocol = SecurityProtocol.PLAINTEXT)
@ -348,7 +347,7 @@ class SocketServerTest extends JUnitSuite { @@ -348,7 +347,7 @@ class SocketServerTest extends JUnitSuite {
// Send requests to `channel1` until a receive is staged and advance time beyond idle time so that `channel1` is
// closed with staged receives and is in Selector.closingChannels
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
val request = sendRequestsUntilStagedReceive(overrideServer, socket1, serializedBytes)
time.sleep(idleTimeMs + 1)
TestUtils.waitUntilTrue(() => openChannel.isEmpty, "Idle channel not closed")
@ -438,7 +437,7 @@ class SocketServerTest extends JUnitSuite { @@ -438,7 +437,7 @@ class SocketServerTest extends JUnitSuite {
TestUtils.waitUntilTrue(() => server.connectionCount(address) < conns.length,
"Failed to decrement connection count after close")
val conn2 = connect()
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
sendRequest(conn2, serializedBytes)
val request = server.requestChannel.receiveRequest(2000)
assertNotNull(request)
@ -457,7 +456,7 @@ class SocketServerTest extends JUnitSuite { @@ -457,7 +456,7 @@ class SocketServerTest extends JUnitSuite {
val conns = (0 until overrideNum).map(_ => connect(overrideServer))
// it should succeed
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
sendRequest(conns.last, serializedBytes)
val request = overrideServer.requestChannel.receiveRequest(2000)
assertNotNull(request)
@ -540,7 +539,7 @@ class SocketServerTest extends JUnitSuite { @@ -540,7 +539,7 @@ class SocketServerTest extends JUnitSuite {
try {
overrideServer.startup()
conn = connect(overrideServer)
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
sendRequest(conn, serializedBytes)
val channel = overrideServer.requestChannel
@ -565,6 +564,54 @@ class SocketServerTest extends JUnitSuite { @@ -565,6 +564,54 @@ class SocketServerTest extends JUnitSuite {
}
}
/**
 * Verifies that a channel is closed when sending a response fails because the client has
 * already hard-closed its socket while the channel still had staged receives.
 * The test stages receives on a connection, resets the connection from the client side,
 * processes the outstanding request, and then expects the channel to be neither open nor
 * closing and the selector to have recorded no completed send.
 */
@Test
def testClientDisconnectionWithStagedReceivesFullyProcessed() {
val serverMetrics = new Metrics
// Set by the `createSelector` override below so the test can inspect the selector afterwards
@volatile var selector: TestableSelector = null
// Fixed connection id returned by the `connectionId` override, used for the channel lookups below
val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0"
val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, Time.SYSTEM, credentialProvider) {
override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = {
new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas,
config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool, new LogContext()) {
override protected[network] def connectionId(socket: Socket): String = overrideConnectionId
override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = {
val testableSelector = new TestableSelector(config, channelBuilder, time, metrics)
selector = testableSelector
testableSelector
}
}
}
}
// Lookup helpers on processor 0: `openChannel` queries `channel(...)`,
// `openOrClosingChannel` queries `openOrClosingChannel(...)` for the overridden connection id
def openChannel: Option[KafkaChannel] = overrideServer.processor(0).channel(overrideConnectionId)
def openOrClosingChannel: Option[KafkaChannel] = overrideServer.processor(0).openOrClosingChannel(overrideConnectionId)
try {
overrideServer.startup()
val socket = connect(overrideServer)
TestUtils.waitUntilTrue(() => openChannel.nonEmpty, "Channel not found")
// Setup channel to client with staged receives so when client disconnects
// it will be stored in Selector.closingChannels
// acks = 1 (rather than the default 0) because sending a response requires acks != 0
val serializedBytes = producerRequestBytes(1)
val request = sendRequestsUntilStagedReceive(overrideServer, socket, serializedBytes)
// Set SoLinger to 0 to force a hard disconnect via TCP RST
socket.setSoLinger(true, 0)
socket.close()
// Complete request with socket exception so that the channel is removed from Selector.closingChannels
processRequest(overrideServer.requestChannel, request)
// The channel must be gone from both lookups, and the failed send must not be reported as completed
TestUtils.waitUntilTrue(() => openOrClosingChannel.isEmpty, "Channel not closed after failed send")
assertTrue("Unexpected completed send", selector.completedSends.isEmpty)
} finally {
// Always release the server and its metrics, even if an assertion above fails
overrideServer.shutdown()
serverMetrics.close()
}
}
/*
* Test that we update request metrics if the channel has been removed from the selector when the broker calls
* `selector.send` (selector closes old connections, for example).
@ -579,7 +626,7 @@ class SocketServerTest extends JUnitSuite { @@ -579,7 +626,7 @@ class SocketServerTest extends JUnitSuite {
try {
overrideServer.startup()
conn = connect(overrideServer)
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
sendRequest(conn, serializedBytes)
val channel = overrideServer.requestChannel
val request = receiveRequest(channel)
@ -698,7 +745,7 @@ class SocketServerTest extends JUnitSuite { @@ -698,7 +745,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector.updateMinWakeup(2)
val sockets = (1 to 2).map(_ => connect(testableServer))
sockets.foreach(sendRequest(_, producerRequestBytes))
sockets.foreach(sendRequest(_, producerRequestBytes()))
testableServer.testableSelector.addFailure(SelectorOperation.Send)
sockets.foreach(_ => processRequest(testableServer.requestChannel))
@ -721,7 +768,7 @@ class SocketServerTest extends JUnitSuite { @@ -721,7 +768,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector.updateMinWakeup(2)
val sockets = (1 to 2).map(_ => connect(testableServer))
sockets.foreach(sendRequest(_, producerRequestBytes))
sockets.foreach(sendRequest(_, producerRequestBytes()))
val requestChannel = testableServer.requestChannel
val requests = sockets.map(_ => receiveRequest(requestChannel))
@ -749,7 +796,7 @@ class SocketServerTest extends JUnitSuite { @@ -749,7 +796,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector.updateMinWakeup(2)
val sockets = (1 to 2).map(_ => connect(testableServer))
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes()
val request = sendRequestsUntilStagedReceive(testableServer, sockets(0), serializedBytes)
sendRequest(sockets(1), serializedBytes)
@ -783,7 +830,7 @@ class SocketServerTest extends JUnitSuite { @@ -783,7 +830,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector.cachedCompletedReceives.minPerPoll = 2
testableSelector.addFailure(SelectorOperation.Mute)
sockets.foreach(sendRequest(_, producerRequestBytes))
sockets.foreach(sendRequest(_, producerRequestBytes()))
val requests = sockets.map(_ => receiveRequest(requestChannel))
testableSelector.waitForOperations(SelectorOperation.Mute, 2)
testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true)
@ -907,7 +954,7 @@ class SocketServerTest extends JUnitSuite { @@ -907,7 +954,7 @@ class SocketServerTest extends JUnitSuite {
// Check new channel behaves as expected
val (socket, connectionId) = connectAndProcessRequest(testableServer)
assertArrayEquals(producerRequestBytes, receiveResponse(socket))
assertArrayEquals(producerRequestBytes(), receiveResponse(socket))
assertNotNull("Channel should not have been closed", selector.channel(connectionId))
assertNull("Channel should not be closing", selector.closingChannel(connectionId))
socket.close()

Loading…
Cancel
Save