diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 383aa09cd76..cb1fd80876e 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -30,7 +30,7 @@ import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrParti
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.utils.Utils
-import org.junit.Assert.{assertFalse, assertTrue}
+import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
 import org.junit.{After, Before, Test}
 import org.mockito.ArgumentMatchers
 import org.mockito.Mockito.{mock, when}
@@ -57,11 +57,11 @@ class PartitionLockTest extends Logging {
   val executorService = Executors.newFixedThreadPool(numReplicaFetchers + numProducers + 1)
   val appendSemaphore = new Semaphore(0)
   val shrinkIsrSemaphore = new Semaphore(0)
+  val followerQueues = (0 until numReplicaFetchers).map(_ => new ArrayBlockingQueue[MemoryRecords](2))
 
   var logManager: LogManager = _
   var partition: Partition = _
 
-
   @Before
   def setUp(): Unit = {
     val logConfig = new LogConfig(new Properties)
@@ -82,7 +82,7 @@ class PartitionLockTest extends Logging {
    */
   @Test
  def testNoLockContentionWithoutIsrUpdate(): Unit = {
-    concurrentProduceFetch(numProducers, numReplicaFetchers, numRecordsPerProducer, appendSemaphore, None)
+    concurrentProduceFetchWithReadLockOnly()
   }
 
   /**
@@ -95,7 +95,7 @@ class PartitionLockTest extends Logging {
     val active = new AtomicBoolean(true)
 
     val future = scheduleShrinkIsr(active, mockTimeSleepMs = 0)
-    concurrentProduceFetch(numProducers, numReplicaFetchers, numRecordsPerProducer, appendSemaphore, None)
+    concurrentProduceFetchWithReadLockOnly()
     active.set(false)
     future.get(15, TimeUnit.SECONDS)
   }
@@ -111,62 +111,75 @@ class PartitionLockTest extends Logging {
 
     val future = scheduleShrinkIsr(active, mockTimeSleepMs = 10000)
     TestUtils.waitUntilTrue(() => shrinkIsrSemaphore.hasQueuedThreads, "shrinkIsr not invoked")
-    concurrentProduceFetch(numProducers, numReplicaFetchers, numRecordsPerProducer, appendSemaphore, Some(shrinkIsrSemaphore))
+    concurrentProduceFetchWithWriteLock()
     active.set(false)
     future.get(15, TimeUnit.SECONDS)
   }
 
-  private def concurrentProduceFetch(numProducers: Int,
-                                     numReplicaFetchers: Int,
-                                     numRecords: Int,
-                                     appendSemaphore: Semaphore,
-                                     shrinkIsrSemaphore: Option[Semaphore]): Unit = {
-    val followerQueues = (0 until numReplicaFetchers).map(_ => new ArrayBlockingQueue[MemoryRecords](2))
+  /**
+   * Perform concurrent appends and replica fetch requests that don't require write lock to
+   * update follower state. Release sufficient append permits to complete all except one append.
+   * Verify that follower state updates complete even though an append holding read lock is in progress.
+   * Then release the permit for the final append and verify that all appends and follower updates complete.
+   */
+  private def concurrentProduceFetchWithReadLockOnly(): Unit = {
+    val appendFutures = scheduleAppends()
+    val stateUpdateFutures = scheduleUpdateFollowers(numProducers * numRecordsPerProducer - 1)
+
+    appendSemaphore.release(numProducers * numRecordsPerProducer - 1)
+    stateUpdateFutures.foreach(_.get(15, TimeUnit.SECONDS))
+
+    appendSemaphore.release(1)
+    scheduleUpdateFollowers(1).foreach(_.get(15, TimeUnit.SECONDS)) // just to make sure follower state update still works
+    appendFutures.foreach(_.get(15, TimeUnit.SECONDS))
+  }
+
+  /**
+   * Perform concurrent appends and replica fetch requests that may require write lock to update
+   * follower state. Threads waiting for write lock to update follower state while append thread is
+   * holding read lock will prevent other threads acquiring the read or write lock. So release sufficient
+   * permits for all appends to complete before verifying state updates.
+   */
+  private def concurrentProduceFetchWithWriteLock(): Unit = {
+
+    val appendFutures = scheduleAppends()
+    val stateUpdateFutures = scheduleUpdateFollowers(numProducers * numRecordsPerProducer)
+
+    assertFalse(stateUpdateFutures.exists(_.isDone))
+    appendSemaphore.release(numProducers * numRecordsPerProducer)
+    assertFalse(appendFutures.exists(_.isDone))
+
+    shrinkIsrSemaphore.release()
+    stateUpdateFutures.foreach(_.get(15, TimeUnit.SECONDS))
+    appendFutures.foreach(_.get(15, TimeUnit.SECONDS))
+  }
 
-    val appendFutures = (0 until numProducers).map { _ =>
+  private def scheduleAppends(): Seq[Future[_]] = {
+    (0 until numProducers).map { _ =>
       executorService.submit((() => {
         try {
           append(partition, numRecordsPerProducer, followerQueues)
         } catch {
           case e: Throwable =>
-            PartitionLockTest.this.error("Exception during append", e)
+            error("Exception during append", e)
             throw e
         }
       }): Runnable)
     }
-    def updateFollower(index: Int, numRecords: Int): Future[_] = {
+  }
+
+  private def scheduleUpdateFollowers(numRecords: Int): Seq[Future[_]] = {
+    (1 to numReplicaFetchers).map { index =>
       executorService.submit((() => {
         try {
           updateFollowerFetchState(partition, index, numRecords, followerQueues(index - 1))
         } catch {
           case e: Throwable =>
-            PartitionLockTest.this.error("Exception during updateFollowerFetchState", e)
+            error("Exception during updateFollowerFetchState", e)
             throw e
         }
       }): Runnable)
     }
-    val stateUpdateFutures = (1 to numReplicaFetchers).map { i =>
-      updateFollower(i, numProducers * numRecordsPerProducer - 1)
-    }
-
-    // Release sufficient append permits to complete all except one append. Verify that
-    // follower state update completes even though an update is in progress. Then release
-    // the permit for the final append and verify follower update for the last append.
-    appendSemaphore.release(numProducers * numRecordsPerProducer - 1)
-
-    // If a semaphore that triggers contention is present, verify that state update futures
-    // haven't completed. Then release the semaphore to enable all threads to continue.
-    shrinkIsrSemaphore.foreach { semaphore =>
-      assertFalse(stateUpdateFutures.exists(_.isDone))
-      semaphore.release()
-    }
-
-    stateUpdateFutures.foreach(_.get(15, TimeUnit.SECONDS))
-    appendSemaphore.release(1)
-    (1 to numReplicaFetchers).map { i =>
-      updateFollower(i, 1).get(15, TimeUnit.SECONDS)
-    }
-    appendFutures.foreach(_.get(15, TimeUnit.SECONDS))
   }
 
   private def scheduleShrinkIsr(activeFlag: AtomicBoolean, mockTimeSleepMs: Long): Future[_] = {
@@ -202,7 +215,11 @@ class PartitionLockTest extends Logging {
 
       override def shrinkIsr(newIsr: Set[Int]): Unit = {
         shrinkIsrSemaphore.acquire()
-        super.shrinkIsr(newIsr)
+        try {
+          super.shrinkIsr(newIsr)
+        } finally {
+          shrinkIsrSemaphore.release()
+        }
       }
 
      override def createLog(replicaId: Int, isNew: Boolean, isFutureReplica: Boolean, offsetCheckpoints: OffsetCheckpoints): Log = {
@@ -215,6 +232,8 @@ class PartitionLockTest extends Logging {
       .thenReturn(None)
     when(stateStore.shrinkIsr(ArgumentMatchers.anyInt, ArgumentMatchers.any[LeaderAndIsr]))
       .thenReturn(Some(2))
+    when(stateStore.expandIsr(ArgumentMatchers.anyInt, ArgumentMatchers.any[LeaderAndIsr]))
+      .thenReturn(Some(2))
 
     partition.createLogIfNotExists(brokerId, isNew = false, isFutureReplica = false, offsetCheckpoints)
 
@@ -258,9 +277,12 @@ class PartitionLockTest extends Logging {
       val batch = followerQueue.poll(15, TimeUnit.SECONDS)
       if (batch == null)
         throw new RuntimeException(s"Timed out waiting for next batch $i")
+      val batches = batch.batches.iterator.asScala.toList
+      assertEquals(1, batches.size)
+      val recordBatch = batches.head
       partition.updateFollowerFetchState(
         followerId,
-        followerFetchOffsetMetadata = LogOffsetMetadata(i),
+        followerFetchOffsetMetadata = LogOffsetMetadata(recordBatch.lastOffset + 1),
         followerStartOffset = 0L,
         followerFetchTimeMs = mockTime.milliseconds(),
         leaderEndOffset = partition.localLogOrException.logEndOffset,