@ -149,7 +149,7 @@ class SocketServerTest extends JUnitSuite {
@@ -149,7 +149,7 @@ class SocketServerTest extends JUnitSuite {
}
def sendAndReceiveRequest ( socket : Socket , server : SocketServer ) : RequestChannel . Request = {
sendRequest ( socket , producerRequestBytes )
sendRequest ( socket , producerRequestBytes ( ) )
receiveRequest ( server . requestChannel )
}
@ -158,11 +158,10 @@ class SocketServerTest extends JUnitSuite {
@@ -158,11 +158,10 @@ class SocketServerTest extends JUnitSuite {
server . metrics . close ( )
}
private def producerRequestBytes : Array [ Byte ] = {
private def producerRequestBytes ( ack : Short = 0 ) : Array [ Byte ] = {
val correlationId = - 1
val clientId = ""
val ackTimeoutMs = 10000
val ack = 0 : Short
val emptyRequest = ProduceRequest . Builder . forCurrentMagic ( ack , ackTimeoutMs ,
new HashMap [ TopicPartition , MemoryRecords ] ( ) ) . build ( )
@ -178,7 +177,7 @@ class SocketServerTest extends JUnitSuite {
@@ -178,7 +177,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def simpleRequest ( ) {
val plainSocket = connect ( protocol = SecurityProtocol . PLAINTEXT )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
// Test PLAINTEXT socket
sendRequest ( plainSocket , serializedBytes )
@ -207,7 +206,7 @@ class SocketServerTest extends JUnitSuite {
@@ -207,7 +206,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def testGracefulClose ( ) {
val plainSocket = connect ( protocol = SecurityProtocol . PLAINTEXT )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
for ( _ <- 0 until 10 )
sendRequest ( plainSocket , serializedBytes )
@ -222,7 +221,7 @@ class SocketServerTest extends JUnitSuite {
@@ -222,7 +221,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def testNoOpAction ( ) : Unit = {
val plainSocket = connect ( protocol = SecurityProtocol . PLAINTEXT )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
for ( _ <- 0 until 3 )
sendRequest ( plainSocket , serializedBytes )
@ -236,7 +235,7 @@ class SocketServerTest extends JUnitSuite {
@@ -236,7 +235,7 @@ class SocketServerTest extends JUnitSuite {
@Test
def testConnectionId ( ) {
val sockets = ( 1 to 5 ) . map ( _ => connect ( protocol = SecurityProtocol . PLAINTEXT ) )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
val requests = sockets . map { socket =>
sendRequest ( socket , serializedBytes )
@ -265,7 +264,7 @@ class SocketServerTest extends JUnitSuite {
@@ -265,7 +264,7 @@ class SocketServerTest extends JUnitSuite {
try {
overrideServer . startup ( )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
// Connection with no staged receives
val socket1 = connect ( overrideServer , protocol = SecurityProtocol . PLAINTEXT )
@ -348,7 +347,7 @@ class SocketServerTest extends JUnitSuite {
@@ -348,7 +347,7 @@ class SocketServerTest extends JUnitSuite {
// Send requests to `channel1` until a receive is staged and advance time beyond idle time so that `channel1` is
// closed with staged receives and is in Selector . closingChannels
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
val request = sendRequestsUntilStagedReceive ( overrideServer , socket1 , serializedBytes )
time . sleep ( idleTimeMs + 1 )
TestUtils . waitUntilTrue ( ( ) => openChannel . isEmpty , "Idle channel not closed" )
@ -438,7 +437,7 @@ class SocketServerTest extends JUnitSuite {
@@ -438,7 +437,7 @@ class SocketServerTest extends JUnitSuite {
TestUtils . waitUntilTrue ( ( ) => server . connectionCount ( address ) < conns . length ,
"Failed to decrement connection count after close" )
val conn2 = connect ( )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
sendRequest ( conn2 , serializedBytes )
val request = server . requestChannel . receiveRequest ( 2000 )
assertNotNull ( request )
@ -457,7 +456,7 @@ class SocketServerTest extends JUnitSuite {
@@ -457,7 +456,7 @@ class SocketServerTest extends JUnitSuite {
val conns = ( 0 until overrideNum ) . map ( _ => connect ( overrideServer ) )
// it should succeed
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
sendRequest ( conns . last , serializedBytes )
val request = overrideServer . requestChannel . receiveRequest ( 2000 )
assertNotNull ( request )
@ -540,7 +539,7 @@ class SocketServerTest extends JUnitSuite {
@@ -540,7 +539,7 @@ class SocketServerTest extends JUnitSuite {
try {
overrideServer . startup ( )
conn = connect ( overrideServer )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
sendRequest ( conn , serializedBytes )
val channel = overrideServer . requestChannel
@ -565,6 +564,54 @@ class SocketServerTest extends JUnitSuite {
@@ -565,6 +564,54 @@ class SocketServerTest extends JUnitSuite {
}
}
@Test
def testClientDisconnectionWithStagedReceivesFullyProcessed ( ) {
val serverMetrics = new Metrics
@volatile var selector : TestableSelector = null
val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0"
val overrideServer = new SocketServer ( KafkaConfig . fromProps ( props ) , serverMetrics , Time . SYSTEM , credentialProvider ) {
override def newProcessor ( id : Int , connectionQuotas : ConnectionQuotas , listenerName : ListenerName ,
protocol : SecurityProtocol , memoryPool : MemoryPool ) : Processor = {
new Processor ( id , time , config . socketRequestMaxBytes , requestChannel , connectionQuotas ,
config . connectionsMaxIdleMs , listenerName , protocol , config , metrics , credentialProvider , memoryPool , new LogContext ( ) ) {
override protected [ network ] def connectionId ( socket : Socket ) : String = overrideConnectionId
override protected [ network ] def createSelector ( channelBuilder : ChannelBuilder ) : Selector = {
val testableSelector = new TestableSelector ( config , channelBuilder , time , metrics )
selector = testableSelector
testableSelector
}
}
}
}
def openChannel : Option [ KafkaChannel ] = overrideServer . processor ( 0 ) . channel ( overrideConnectionId )
def openOrClosingChannel : Option [ KafkaChannel ] = overrideServer . processor ( 0 ) . openOrClosingChannel ( overrideConnectionId )
try {
overrideServer . startup ( )
val socket = connect ( overrideServer )
TestUtils . waitUntilTrue ( ( ) => openChannel . nonEmpty , "Channel not found" )
// Setup channel to client with staged receives so when client disconnects
// it will be stored in Selector . closingChannels
val serializedBytes = producerRequestBytes ( 1 )
val request = sendRequestsUntilStagedReceive ( overrideServer , socket , serializedBytes )
// Set SoLinger to 0 to force a hard disconnect via TCP RST
socket . setSoLinger ( true , 0 )
socket . close ( )
// Complete request with socket exception so that the channel is removed from Selector . closingChannels
processRequest ( overrideServer . requestChannel , request )
TestUtils . waitUntilTrue ( ( ) => openOrClosingChannel . isEmpty , "Channel not closed after failed send" )
assertTrue ( "Unexpected completed send" , selector . completedSends . isEmpty )
} finally {
overrideServer . shutdown ( )
serverMetrics . close ( )
}
}
/*
* Test that we update request metrics if the channel has been removed from the selector when the broker calls
* `selector.send` ( selector closes old connections , for example ) .
@ -579,7 +626,7 @@ class SocketServerTest extends JUnitSuite {
@@ -579,7 +626,7 @@ class SocketServerTest extends JUnitSuite {
try {
overrideServer . startup ( )
conn = connect ( overrideServer )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
sendRequest ( conn , serializedBytes )
val channel = overrideServer . requestChannel
val request = receiveRequest ( channel )
@ -698,7 +745,7 @@ class SocketServerTest extends JUnitSuite {
@@ -698,7 +745,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector . updateMinWakeup ( 2 )
val sockets = ( 1 to 2 ) . map ( _ => connect ( testableServer ) )
sockets . foreach ( sendRequest ( _ , producerRequestBytes ) )
sockets . foreach ( sendRequest ( _ , producerRequestBytes ( ) ) )
testableServer . testableSelector . addFailure ( SelectorOperation . Send )
sockets . foreach ( _ => processRequest ( testableServer . requestChannel ) )
@ -721,7 +768,7 @@ class SocketServerTest extends JUnitSuite {
@@ -721,7 +768,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector . updateMinWakeup ( 2 )
val sockets = ( 1 to 2 ) . map ( _ => connect ( testableServer ) )
sockets . foreach ( sendRequest ( _ , producerRequestBytes ) )
sockets . foreach ( sendRequest ( _ , producerRequestBytes ( ) ) )
val requestChannel = testableServer . requestChannel
val requests = sockets . map ( _ => receiveRequest ( requestChannel ) )
@ -749,7 +796,7 @@ class SocketServerTest extends JUnitSuite {
@@ -749,7 +796,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector . updateMinWakeup ( 2 )
val sockets = ( 1 to 2 ) . map ( _ => connect ( testableServer ) )
val serializedBytes = producerRequestBytes
val serializedBytes = producerRequestBytes ( )
val request = sendRequestsUntilStagedReceive ( testableServer , sockets ( 0 ) , serializedBytes )
sendRequest ( sockets ( 1 ) , serializedBytes )
@ -783,7 +830,7 @@ class SocketServerTest extends JUnitSuite {
@@ -783,7 +830,7 @@ class SocketServerTest extends JUnitSuite {
testableSelector . cachedCompletedReceives . minPerPoll = 2
testableSelector . addFailure ( SelectorOperation . Mute )
sockets . foreach ( sendRequest ( _ , producerRequestBytes ) )
sockets . foreach ( sendRequest ( _ , producerRequestBytes ( ) ) )
val requests = sockets . map ( _ => receiveRequest ( requestChannel ) )
testableSelector . waitForOperations ( SelectorOperation . Mute , 2 )
testableServer . waitForChannelClose ( testableSelector . allFailedChannels . head , locallyClosed = true )
@ -907,7 +954,7 @@ class SocketServerTest extends JUnitSuite {
@@ -907,7 +954,7 @@ class SocketServerTest extends JUnitSuite {
// Check new channel behaves as expected
val ( socket , connectionId ) = connectAndProcessRequest ( testableServer )
assertArrayEquals ( producerRequestBytes , receiveResponse ( socket ) )
assertArrayEquals ( producerRequestBytes ( ) , receiveResponse ( socket ) )
assertNotNull ( "Channel should not have been closed" , selector . channel ( connectionId ) )
assertNull ( "Channel should not be closing" , selector . closingChannel ( connectionId ) )
socket . close ( )