
KAFKA-8334 Make sure the thread which tries to complete delayed reque… (#8657)

The main changes of this PR are shown below.

1. replace tryLock with lock for DelayedOperation#maybeTryComplete
2. complete the delayed requests without holding the group lock

Reviewers: Ismael Juma <ismael@juma.me.uk>, Jun Rao <junrao@gmail.com>
Authored by Chia-Ping Tsai 4 years ago, committed by GitHub
commit c2273adc25
  1. core/src/main/scala/kafka/cluster/Partition.scala (22 changes)
  2. core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala (11 changes)
  3. core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala (2 changes)
  4. core/src/main/scala/kafka/log/Log.scala (13 changes)
  5. core/src/main/scala/kafka/server/ActionQueue.scala (56 changes)
  6. core/src/main/scala/kafka/server/DelayedOperation.scala (118 changes)
  7. core/src/main/scala/kafka/server/KafkaApis.scala (4 changes)
  8. core/src/main/scala/kafka/server/ReplicaManager.scala (31 changes)
  9. core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala (10 changes)
  10. core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala (65 changes)
  11. core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala (5 changes)
  12. core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala (96 changes)
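
For orientation before the per-file diffs: change (1) drops the tryLock-based retry loop of maybeTryComplete() in favour of helpers that hold the operation's lock for the whole completion check. The following is a minimal, self-contained sketch of that pattern, not Kafka code; the class and helper names are illustrative, and the real helpers appear in the DelayedOperation.scala hunk below.

import java.util.concurrent.locks.ReentrantLock

// Illustrative stand-in for the new DelayedOperation helpers added in this commit.
abstract class LockedCompletionSketch {
  protected val lock = new ReentrantLock()

  // completion check, assumed to run only while `lock` is held
  def tryComplete(): Boolean

  private def inLock[T](body: => T): T = {
    lock.lock()
    try body finally lock.unlock()
  }

  // blockingly acquire the lock instead of tryLock(): no retry flag is needed anymore
  def safeTryComplete(): Boolean = inLock(tryComplete())

  // run `f` (e.g. add the operation to watch lists) and re-check completion, all under the lock
  def safeTryCompleteOrElse(f: => Unit): Boolean = inLock {
    if (tryComplete()) true
    else {
      f
      tryComplete()
    }
  }
}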

core/src/main/scala/kafka/cluster/Partition.scala (22 changes)

@@ -103,19 +103,7 @@ class DelayedOperations(topicPartition: TopicPartition,
fetch.checkAndComplete(TopicPartitionOperationKey(topicPartition))
}
def checkAndCompleteProduce(): Unit = {
produce.checkAndComplete(TopicPartitionOperationKey(topicPartition))
}
def checkAndCompleteDeleteRecords(): Unit = {
deleteRecords.checkAndComplete(TopicPartitionOperationKey(topicPartition))
}
def numDelayedDelete: Int = deleteRecords.numDelayed
def numDelayedFetch: Int = fetch.numDelayed
def numDelayedProduce: Int = produce.numDelayed
}
object Partition extends KafkaMetricsGroup {
@@ -1010,15 +998,7 @@ class Partition(val topicPartition: TopicPartition,
}
}
// some delayed operations may be unblocked after HW changed
if (leaderHWIncremented)
tryCompleteDelayedRequests()
else {
// probably unblock some follower fetch requests since log end offset has been updated
delayedOperations.checkAndCompleteFetch()
}
info
info.copy(leaderHwChange = if (leaderHWIncremented) LeaderHwChange.Increased else LeaderHwChange.Same)
}
def readRecords(fetchOffset: Long,

core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala (11 changes)

@@ -36,8 +36,15 @@ private[group] class DelayedJoin(coordinator: GroupCoordinator,
rebalanceTimeout: Long) extends DelayedOperation(rebalanceTimeout, Some(group.lock)) {
override def tryComplete(): Boolean = coordinator.tryCompleteJoin(group, forceComplete _)
override def onExpiration() = coordinator.onExpireJoin()
override def onComplete() = coordinator.onCompleteJoin(group)
override def onExpiration(): Unit = {
coordinator.onExpireJoin()
// try to complete delayed actions introduced by coordinator.onCompleteJoin
tryToCompleteDelayedAction()
}
override def onComplete(): Unit = coordinator.onCompleteJoin(group)
// TODO: remove this ugly chain after we move the action queue to handler thread
private def tryToCompleteDelayedAction(): Unit = coordinator.groupManager.replicaManager.tryCompleteActions()
}
/**

core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala (2 changes)

@@ -56,7 +56,7 @@ import scala.jdk.CollectionConverters._
class GroupMetadataManager(brokerId: Int,
interBrokerProtocolVersion: ApiVersion,
config: OffsetConfig,
replicaManager: ReplicaManager,
val replicaManager: ReplicaManager,
zkClient: KafkaZkClient,
time: Time,
metrics: Metrics) extends Logging with KafkaMetricsGroup {

core/src/main/scala/kafka/log/Log.scala (13 changes)

@@ -68,6 +68,13 @@ object LogAppendInfo {
offsetsMonotonic = false, -1L, recordErrors, errorMessage)
}
sealed trait LeaderHwChange
object LeaderHwChange {
case object Increased extends LeaderHwChange
case object Same extends LeaderHwChange
case object None extends LeaderHwChange
}
/**
* Struct to hold various quantities we compute about each message set before appending to the log
*
@@ -85,6 +92,9 @@ object LogAppendInfo {
* @param validBytes The number of valid bytes
* @param offsetsMonotonic Are the offsets in this message set monotonically increasing
* @param lastOffsetOfFirstBatch The last offset of the first batch
* @param leaderHwChange Increased if the high watermark needs to be increased after appending records.
* Same if the high watermark is not changed. None is the default value and means the append failed
*
*/
case class LogAppendInfo(var firstOffset: Option[Long],
var lastOffset: Long,
@@ -100,7 +110,8 @@ case class LogAppendInfo(var firstOffset: Option[Long],
offsetsMonotonic: Boolean,
lastOffsetOfFirstBatch: Long,
recordErrors: Seq[RecordError] = List(),
errorMessage: String = null) {
errorMessage: String = null,
leaderHwChange: LeaderHwChange = LeaderHwChange.None) {
/**
* Get the first offset if it exists, else get the last offset of the first batch
* For magic versions 2 and newer, this method will return first offset. For magic versions

core/src/main/scala/kafka/server/ActionQueue.scala (56 changes)

@@ -0,0 +1,56 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.server
import java.util.concurrent.ConcurrentLinkedQueue
import kafka.utils.Logging
/**
* This queue is used to collect actions which need to be executed later. One use case is that ReplicaManager#appendRecords
* produces record changes so we need to check and complete delayed requests. In order to avoid conflicting locking,
* we add those actions to this queue and then complete them at the end of KafkaApis.handle() or DelayedJoin.onExpiration.
*/
class ActionQueue extends Logging {
private val queue = new ConcurrentLinkedQueue[() => Unit]()
/**
* add an action to this queue.
* @param action the action to execute later
*/
def add(action: () => Unit): Unit = queue.add(action)
/**
* try to complete all delayed actions
*/
def tryCompleteActions(): Unit = {
val maxToComplete = queue.size()
var count = 0
var done = false
while (!done && count < maxToComplete) {
try {
val action = queue.poll()
if (action == null) done = true
else action()
} catch {
case e: Throwable =>
error("failed to complete delayed actions", e)
} finally count += 1
}
}
}
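
A minimal usage sketch of the class above; the owning object and the println body are illustrative stand-ins for ReplicaManager and for a purgatory checkAndComplete() call, which is how the queue is used in the hunks below.

import kafka.server.ActionQueue

// Illustrative driver; ActionQueue itself is the class introduced above.
object ActionQueueUsageSketch extends App {
  val actionQueue = new ActionQueue

  // Producer side (e.g. inside ReplicaManager.appendRecords): enqueue the deferred completion check
  // instead of running it while a conflicting lock may be held.
  actionQueue.add { () =>
    println("check and complete delayed requests for the affected partition")
  }

  // Consumer side: drained later from a context that holds no conflicting locks,
  // e.g. at the end of KafkaApis.handle() or from DelayedJoin.onExpiration().
  actionQueue.tryCompleteActions()
}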

core/src/main/scala/kafka/server/DelayedOperation.scala (118 changes)

@@ -41,13 +41,15 @@ import scala.collection.mutable.ListBuffer
* forceComplete().
*
* A subclass of DelayedOperation needs to provide an implementation of both onComplete() and tryComplete().
*
* Note that if you add a future delayed operation that calls ReplicaManager.appendRecords() in onComplete(),
* like DelayedJoin, you must be aware that this operation's onExpiration() needs to call actionQueue.tryCompleteActions().
*/
abstract class DelayedOperation(override val delayMs: Long,
lockOpt: Option[Lock] = None)
extends TimerTask with Logging {
private val completed = new AtomicBoolean(false)
private val tryCompletePending = new AtomicBoolean(false)
// Visible for testing
private[server] val lock: Lock = lockOpt.getOrElse(new ReentrantLock)
@@ -100,42 +102,24 @@ abstract class DelayedOperation(override val delayMs: Long,
def tryComplete(): Boolean
/**
* Thread-safe variant of tryComplete() that attempts completion only if the lock can be acquired
* without blocking.
*
* If threadA acquires the lock and performs the check for completion before completion criteria is met
* and threadB satisfies the completion criteria, but fails to acquire the lock because threadA has not
* yet released the lock, we need to ensure that completion is attempted again without blocking threadA
* or threadB. `tryCompletePending` is set by threadB when it fails to acquire the lock and at least one
* of threadA or threadB will attempt completion of the operation if this flag is set. This ensures that
* every invocation of `maybeTryComplete` is followed by at least one invocation of `tryComplete` until
* the operation is actually completed.
* Thread-safe variant of tryComplete() that calls an extra function if the first tryComplete returns false
* @param f function to be executed after the first tryComplete returns false
* @return result of the final tryComplete
*/
private[server] def maybeTryComplete(): Boolean = {
var retry = false
var done = false
do {
if (lock.tryLock()) {
try {
tryCompletePending.set(false)
done = tryComplete()
} finally {
lock.unlock()
}
// While we were holding the lock, another thread may have invoked `maybeTryComplete` and set
// `tryCompletePending`. In this case we should retry.
retry = tryCompletePending.get()
} else {
// Another thread is holding the lock. If `tryCompletePending` is already set and this thread failed to
// acquire the lock, then the thread that is holding the lock is guaranteed to see the flag and retry.
// Otherwise, we should set the flag and retry on this thread since the thread holding the lock may have
// released the lock and returned by the time the flag is set.
retry = !tryCompletePending.getAndSet(true)
}
} while (!isCompleted && retry)
done
private[server] def safeTryCompleteOrElse(f: => Unit): Boolean = inLock(lock) {
if (tryComplete()) true
else {
f
// last completion check
tryComplete()
}
}
/**
* Thread-safe variant of tryComplete()
*/
private[server] def safeTryComplete(): Boolean = inLock(lock)(tryComplete())
/*
* run() method defines a task that is executed on timeout
*/
@@ -219,38 +203,38 @@ final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: Stri
def tryCompleteElseWatch(operation: T, watchKeys: Seq[Any]): Boolean = {
assert(watchKeys.nonEmpty, "The watch key list can't be empty")
// The cost of tryComplete() is typically proportional to the number of keys. Calling
// tryComplete() for each key is going to be expensive if there are many keys. Instead,
// we do the check in the following way. Call tryComplete(). If the operation is not completed,
// we just add the operation to all keys. Then we call tryComplete() again. At this time, if
// the operation is still not completed, we are guaranteed that it won't miss any future triggering
// event since the operation is already on the watcher list for all keys. This does mean that
// if the operation is completed (by another thread) between the two tryComplete() calls, the
// operation is unnecessarily added for watch. However, this is a less severe issue since the
// expire reaper will clean it up periodically.
// At this point the only thread that can attempt this operation is this current thread
// Hence it is safe to tryComplete() without a lock
var isCompletedByMe = operation.tryComplete()
if (isCompletedByMe)
return true
var watchCreated = false
for(key <- watchKeys) {
// If the operation is already completed, stop adding it to the rest of the watcher list.
if (operation.isCompleted)
return false
watchForOperation(key, operation)
if (!watchCreated) {
watchCreated = true
estimatedTotalOperations.incrementAndGet()
}
}
isCompletedByMe = operation.maybeTryComplete()
if (isCompletedByMe)
return true
// The cost of tryComplete() is typically proportional to the number of keys. Calling tryComplete() for each key is
// going to be expensive if there are many keys. Instead, we do the check in the following way through safeTryCompleteOrElse().
// If the operation is not completed, we just add the operation to all keys. Then we call tryComplete() again. At
// this time, if the operation is still not completed, we are guaranteed that it won't miss any future triggering
// event since the operation is already on the watcher list for all keys.
//
// ==============[story about lock]==============
// Through safeTryCompleteOrElse(), we hold the operation's lock while adding the operation to watch list and doing
// the tryComplete() check. This is to avoid a potential deadlock between the callers to tryCompleteElseWatch() and
// checkAndComplete(). For example, the following deadlock can happen if the lock is only held for the final tryComplete()
// 1) thread_a holds the read lock of stateLock from TransactionStateManager
// 2) thread_a is executing tryCompleteElseWatch()
// 3) thread_a adds op to watch list
// 4) thread_b requests the write lock of stateLock from TransactionStateManager (blocked by thread_a)
// 5) thread_c calls checkAndComplete() and holds the lock of op
// 6) thread_c is waiting for the read lock of stateLock to complete op (blocked by thread_b)
// 7) thread_a is waiting for the lock of op to call the final tryComplete() (blocked by thread_c)
//
// Note that even with the current approach, deadlocks could still be introduced. For example,
// 1) thread_a calls tryCompleteElseWatch() and gets the lock of op
// 2) thread_a adds op to watch list
// 3) thread_a calls op#tryComplete and tries to acquire lock_b
// 4) thread_b holds lock_b and calls checkAndComplete()
// 5) thread_b sees op from watch list
// 6) thread_b needs the lock of op
// To avoid the above scenario, we recommend that DelayedOperationPurgatory.checkAndComplete() be called without holding
// any exclusive lock. Since DelayedOperationPurgatory.checkAndComplete() completes delayed operations asynchronously,
// holding an exclusive lock to make the call is often unnecessary.
if (operation.safeTryCompleteOrElse {
watchKeys.foreach(key => watchForOperation(key, operation))
if (watchKeys.nonEmpty) estimatedTotalOperations.incrementAndGet()
}) return true
// if it cannot be completed by now and hence is watched, add to the expire queue also
if (!operation.isCompleted) {
@@ -375,7 +359,7 @@ final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: Stri
if (curr.isCompleted) {
// another thread has completed this operation, just remove it
iter.remove()
} else if (curr.maybeTryComplete()) {
} else if (curr.safeTryComplete()) {
iter.remove()
completed += 1
}
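
To make the contract in the class comment above concrete, here is a hedged sketch of a minimal subclass in the style of DelayedJoin; DelayedExample and its constructor parameters are hypothetical, and only the DelayedOperation and ReplicaManager methods shown in this diff are assumed.

import kafka.server.{DelayedOperation, ReplicaManager}

// Hypothetical subclass sketch (not part of the commit); names and the completion condition are illustrative.
class DelayedExample(delayMs: Long,
                     replicaManager: ReplicaManager,
                     condition: () => Boolean) extends DelayedOperation(delayMs) {

  // Invoked while the operation's lock is held, via safeTryComplete()/safeTryCompleteOrElse().
  override def tryComplete(): Boolean =
    if (condition()) forceComplete() else false

  // Runs exactly once, either on completion or on expiration.
  override def onComplete(): Unit = {
    // do the actual work here, e.g. send a response or append records
  }

  // If onComplete() appends records via ReplicaManager.appendRecords(), the class comment above
  // says the expiration path must drain the action queue, as DelayedJoin.onExpiration() now does.
  override def onExpiration(): Unit = replicaManager.tryCompleteActions()
}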

core/src/main/scala/kafka/server/KafkaApis.scala (4 changes)

@@ -186,6 +186,10 @@ class KafkaApis(val requestChannel: RequestChannel,
case e: FatalExitError => throw e
case e: Throwable => handleError(request, e)
} finally {
// Try to complete delayed actions. In order to avoid conflicting locking, the actions to complete delayed requests
// are kept in a queue. The ReplicaManager queue is checked at the end of KafkaApis.handle() and in the
// expiration thread of certain delayed operations (e.g. DelayedJoin)
replicaManager.tryCompleteActions()
// The local completion time may be set while processing the request. Only record it if it's unset.
if (request.apiLocalCompleteTimeNanos < 0)
request.apiLocalCompleteTimeNanos = time.nanoseconds

core/src/main/scala/kafka/server/ReplicaManager.scala (31 changes)

@@ -558,10 +558,21 @@ class ReplicaManager(val config: KafkaConfig,
localLog(topicPartition).map(_.parentDir)
}
/**
* TODO: move this action queue to the handler thread so we can simplify concurrency handling
*/
private val actionQueue = new ActionQueue
def tryCompleteActions(): Unit = actionQueue.tryCompleteActions()
/**
* Append messages to leader replicas of the partition, and wait for them to be replicated to other replicas;
* the callback function will be triggered either when timeout or the required acks are satisfied;
* if the callback function itself is already synchronized on some object then pass this object to avoid deadlock.
*
* Note that all pending delayed check operations are stored in a queue. All callers to ReplicaManager.appendRecords()
* are expected to call ActionQueue.tryCompleteActions for all affected partitions, without holding any conflicting
* locks.
*/
def appendRecords(timeout: Long,
requiredAcks: Short,
@@ -585,6 +596,26 @@ class ReplicaManager(val config: KafkaConfig,
result.info.logStartOffset, result.info.recordErrors.asJava, result.info.errorMessage)) // response status
}
actionQueue.add {
() =>
localProduceResults.foreach {
case (topicPartition, result) =>
val requestKey = TopicPartitionOperationKey(topicPartition)
result.info.leaderHwChange match {
case LeaderHwChange.Increased =>
// some delayed operations may be unblocked after HW changed
delayedProducePurgatory.checkAndComplete(requestKey)
delayedFetchPurgatory.checkAndComplete(requestKey)
delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
case LeaderHwChange.Same =>
// probably unblock some follower fetch requests since log end offset has been updated
delayedFetchPurgatory.checkAndComplete(requestKey)
case LeaderHwChange.None =>
// nothing
}
}
}
recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
if (delayedProduceRequestRequired(requiredAcks, entriesPerPartition, localProduceResults)) {
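
The caller-side contract described in the appendRecords scaladoc above reduces to the pattern sketched below; the handler method is hypothetical and the appendRecords arguments are elided, mirroring what KafkaApis.handle() now does in its finally block.

import kafka.server.ReplicaManager

// Hedged sketch of the expected calling pattern; handleSomeProduceLikeRequest is hypothetical.
object AppendRecordsCallerSketch {
  def handleSomeProduceLikeRequest(replicaManager: ReplicaManager): Unit = {
    try {
      // replicaManager.appendRecords(...)   // only enqueues the delayed-request checks into the action queue
    } finally {
      // drain the queue afterwards, while holding no conflicting locks,
      // exactly as the KafkaApis.handle() hunk earlier in this diff does
      replicaManager.tryCompleteActions()
    }
  }
}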

core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala (10 changes)

@@ -17,8 +17,8 @@
package kafka.coordinator
import java.util.{Collections, Random}
import java.util.concurrent.{ConcurrentHashMap, Executors}
import java.util.{Collections, Random}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.Lock
@@ -97,7 +97,7 @@ abstract class AbstractCoordinatorConcurrencyTest[M <: CoordinatorMember] {
}
def enableCompletion(): Unit = {
replicaManager.tryCompleteDelayedRequests()
replicaManager.tryCompleteActions()
scheduler.tick()
}
@@ -166,9 +166,8 @@ object AbstractCoordinatorConcurrencyTest {
producePurgatory = new DelayedOperationPurgatory[DelayedProduce]("Produce", timer, 1, reaperEnabled = false)
watchKeys = Collections.newSetFromMap(new ConcurrentHashMap[TopicPartitionOperationKey, java.lang.Boolean]()).asScala
}
def tryCompleteDelayedRequests(): Unit = {
watchKeys.map(producePurgatory.checkAndComplete)
}
override def tryCompleteActions(): Unit = watchKeys.map(producePurgatory.checkAndComplete)
override def appendRecords(timeout: Long,
requiredAcks: Short,
@@ -204,7 +203,6 @@ object AbstractCoordinatorConcurrencyTest {
val producerRequestKeys = entriesPerPartition.keys.map(TopicPartitionOperationKey(_)).toSeq
watchKeys ++= producerRequestKeys
producePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
tryCompleteDelayedRequests()
}
override def getMagic(topicPartition: TopicPartition): Option[Byte] = {
Some(RecordBatch.MAGIC_VALUE_V2)

core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala (65 changes)

@@ -18,6 +18,7 @@
package kafka.coordinator.group
import java.util.Properties
import java.util.concurrent.locks.{Lock, ReentrantLock}
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import kafka.common.OffsetAndMetadata
@@ -60,16 +61,6 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
new LeaveGroupOperation
)
private val allOperationsWithTxn = Seq(
new JoinGroupOperation,
new SyncGroupOperation,
new OffsetFetchOperation,
new CommitTxnOffsetsOperation,
new CompleteTxnOperation,
new HeartbeatOperation,
new LeaveGroupOperation
)
var heartbeatPurgatory: DelayedOperationPurgatory[DelayedHeartbeat] = _
var joinPurgatory: DelayedOperationPurgatory[DelayedJoin] = _
var groupCoordinator: GroupCoordinator = _
@@ -119,12 +110,33 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
@Test
def testConcurrentTxnGoodPathSequence(): Unit = {
verifyConcurrentOperations(createGroupMembers, allOperationsWithTxn)
verifyConcurrentOperations(createGroupMembers, Seq(
new JoinGroupOperation,
new SyncGroupOperation,
new OffsetFetchOperation,
new CommitTxnOffsetsOperation,
new CompleteTxnOperation,
new HeartbeatOperation,
new LeaveGroupOperation
))
}
@Test
def testConcurrentRandomSequence(): Unit = {
verifyConcurrentRandomSequences(createGroupMembers, allOperationsWithTxn)
/**
* handleTxnCommitOffsets does not complete delayed requests now, so it causes an error if handleTxnCompletion is executed
* before the delayed requests are completed. In random mode, we use this global lock to prevent such an error.
*/
val lock = new ReentrantLock()
verifyConcurrentRandomSequences(createGroupMembers, Seq(
new JoinGroupOperation,
new SyncGroupOperation,
new OffsetFetchOperation,
new CommitTxnOffsetsOperation(lock = Some(lock)),
new CompleteTxnOperation(lock = Some(lock)),
new HeartbeatOperation,
new LeaveGroupOperation
))
}
@Test
@@ -198,6 +210,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
groupCoordinator.handleJoinGroup(member.groupId, member.memberId, None, requireKnownMemberId = false, "clientId", "clientHost",
DefaultRebalanceTimeout, DefaultSessionTimeout,
protocolType, protocols, responseCallback)
replicaManager.tryCompleteActions()
}
override def awaitAndVerify(member: GroupMember): Unit = {
val joinGroupResult = await(member, DefaultRebalanceTimeout)
@@ -221,6 +234,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId,
Some(protocolType), Some(protocolName), member.groupInstanceId, Map.empty[String, Array[Byte]], responseCallback)
}
replicaManager.tryCompleteActions()
}
override def awaitAndVerify(member: GroupMember): Unit = {
val result = await(member, DefaultSessionTimeout)
@@ -238,6 +252,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
override def runWithCallback(member: GroupMember, responseCallback: HeartbeatCallback): Unit = {
groupCoordinator.handleHeartbeat(member.groupId, member.memberId,
member.groupInstanceId, member.generationId, responseCallback)
replicaManager.tryCompleteActions()
}
override def awaitAndVerify(member: GroupMember): Unit = {
val error = await(member, DefaultSessionTimeout)
@@ -252,6 +267,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
}
override def runWithCallback(member: GroupMember, responseCallback: OffsetFetchCallback): Unit = {
val (error, partitionData) = groupCoordinator.handleFetchOffsets(member.groupId, requireStable = true, None)
replicaManager.tryCompleteActions()
responseCallback(error, partitionData)
}
override def awaitAndVerify(member: GroupMember): Unit = {
@@ -271,6 +287,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds()))
groupCoordinator.handleCommitOffsets(member.groupId, member.memberId,
member.groupInstanceId, member.generationId, offsets, responseCallback)
replicaManager.tryCompleteActions()
}
override def awaitAndVerify(member: GroupMember): Unit = {
val offsets = await(member, 500)
@@ -278,7 +295,7 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
}
}
class CommitTxnOffsetsOperation extends CommitOffsetsOperation {
class CommitTxnOffsetsOperation(lock: Option[Lock] = None) extends CommitOffsetsOperation {
override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = {
val tp = new TopicPartition("topic", 0)
val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds()))
@@ -293,13 +310,17 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean)
responseCallback(errors)
}
groupCoordinator.handleTxnCommitOffsets(member.group.groupId, producerId, producerEpoch,
JoinGroupRequest.UNKNOWN_MEMBER_ID, Option.empty, JoinGroupRequest.UNKNOWN_GENERATION_ID,
offsets, callbackWithTxnCompletion)
lock.foreach(_.lock())
try {
groupCoordinator.handleTxnCommitOffsets(member.group.groupId, producerId, producerEpoch,
JoinGroupRequest.UNKNOWN_MEMBER_ID, Option.empty, JoinGroupRequest.UNKNOWN_GENERATION_ID,
offsets, callbackWithTxnCompletion)
replicaManager.tryCompleteActions()
} finally lock.foreach(_.unlock())
}
}
class CompleteTxnOperation extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback] {
class CompleteTxnOperation(lock: Option[Lock] = None) extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback] {
override def responseCallback(responsePromise: Promise[CompleteTxnCallbackParams]): CompleteTxnCallback = {
val callback: CompleteTxnCallback = error => responsePromise.success(error)
callback
@@ -307,9 +328,13 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
override def runWithCallback(member: GroupMember, responseCallback: CompleteTxnCallback): Unit = {
val producerId = 1000L
val offsetsPartitions = (0 to numPartitions).map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, _))
groupCoordinator.groupManager.handleTxnCompletion(producerId,
offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean)
responseCallback(Errors.NONE)
lock.foreach(_.lock())
try {
groupCoordinator.groupManager.handleTxnCompletion(producerId,
offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean)
responseCallback(Errors.NONE)
} finally lock.foreach(_.unlock())
}
override def awaitAndVerify(member: GroupMember): Unit = {
val error = await(member, 500)

core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala (5 changes)

@@ -509,6 +509,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
class InitProducerIdOperation(val producerIdAndEpoch: Option[ProducerIdAndEpoch] = None) extends TxnOperation[InitProducerIdResult] {
override def run(txn: Transaction): Unit = {
transactionCoordinator.handleInitProducerId(txn.transactionalId, 60000, producerIdAndEpoch, resultCallback)
replicaManager.tryCompleteActions()
}
override def awaitAndVerify(txn: Transaction): Unit = {
val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed"))
@@ -525,6 +526,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
txnMetadata.producerEpoch,
partitions,
resultCallback)
replicaManager.tryCompleteActions()
}
}
override def awaitAndVerify(txn: Transaction): Unit = {
@@ -597,12 +599,13 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
}
}
txnStateManager.enableTransactionalIdExpiration()
replicaManager.tryCompleteActions()
time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs + 1)
}
override def await(): Unit = {
val (_, success) = TestUtils.computeUntilTrue({
replicaManager.tryCompleteDelayedRequests()
replicaManager.tryCompleteActions()
transactions.forall(txn => transactionMetadata(txn).isEmpty)
})(identity)
assertTrue("Transaction not expired", success)

core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala (96 changes)

@@ -33,12 +33,12 @@ import scala.jdk.CollectionConverters._
class DelayedOperationTest {
var purgatory: DelayedOperationPurgatory[MockDelayedOperation] = null
var purgatory: DelayedOperationPurgatory[DelayedOperation] = null
var executorService: ExecutorService = null
@Before
def setUp(): Unit = {
purgatory = DelayedOperationPurgatory[MockDelayedOperation](purgatoryName = "mock")
purgatory = DelayedOperationPurgatory[DelayedOperation](purgatoryName = "mock")
}
@After
@@ -48,6 +48,43 @@ class DelayedOperationTest {
executorService.shutdown()
}
@Test
def testLockInTryCompleteElseWatch(): Unit = {
val op = new DelayedOperation(100000L) {
override def onExpiration(): Unit = {}
override def onComplete(): Unit = {}
override def tryComplete(): Boolean = {
assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread)
false
}
override def safeTryComplete(): Boolean = {
fail("tryCompleteElseWatch should not use safeTryComplete")
super.safeTryComplete()
}
}
purgatory.tryCompleteElseWatch(op, Seq("key"))
}
@Test
def testSafeTryCompleteOrElse(): Unit = {
def op(shouldComplete: Boolean) = new DelayedOperation(100000L) {
override def onExpiration(): Unit = {}
override def onComplete(): Unit = {}
override def tryComplete(): Boolean = {
assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread)
shouldComplete
}
}
var pass = false
assertFalse(op(false).safeTryCompleteOrElse {
pass = true
})
assertTrue(pass)
assertTrue(op(true).safeTryCompleteOrElse {
fail("this method should NOT be executed")
})
}
@Test
def testRequestSatisfaction(): Unit = {
val r1 = new MockDelayedOperation(100000L)
@@ -192,44 +229,6 @@ class DelayedOperationTest {
assertEquals(Nil, cancelledOperations)
}
/**
* Verify that if there is lock contention between two threads attempting to complete,
* completion is performed without any blocking in either thread.
*/
@Test
def testTryCompleteLockContention(): Unit = {
executorService = Executors.newSingleThreadExecutor()
val completionAttemptsRemaining = new AtomicInteger(Int.MaxValue)
val tryCompleteSemaphore = new Semaphore(1)
val key = "key"
val op = new MockDelayedOperation(100000L, None, None) {
override def tryComplete() = {
val shouldComplete = completionAttemptsRemaining.decrementAndGet <= 0
tryCompleteSemaphore.acquire()
try {
if (shouldComplete)
forceComplete()
else
false
} finally {
tryCompleteSemaphore.release()
}
}
}
purgatory.tryCompleteElseWatch(op, Seq(key))
completionAttemptsRemaining.set(2)
tryCompleteSemaphore.acquire()
val future = runOnAnotherThread(purgatory.checkAndComplete(key), shouldComplete = false)
TestUtils.waitUntilTrue(() => tryCompleteSemaphore.hasQueuedThreads, "Not attempting to complete")
purgatory.checkAndComplete(key) // this should not block even though lock is not free
assertFalse("Operation should not have completed", op.isCompleted)
tryCompleteSemaphore.release()
future.get(10, TimeUnit.SECONDS)
assertTrue("Operation should have completed", op.isCompleted)
}
/**
* Test `tryComplete` with multiple threads to verify that there are no timing windows
* when completion is not performed even if the thread that makes the operation completable
@@ -280,23 +279,6 @@ class DelayedOperationTest {
ops.foreach { op => assertTrue("Operation should have completed", op.isCompleted) }
}
@Test
def testDelayedOperationLock(): Unit = {
verifyDelayedOperationLock(new MockDelayedOperation(100000L), mismatchedLocks = false)
}
@Test
def testDelayedOperationLockOverride(): Unit = {
def newMockOperation = {
val lock = new ReentrantLock
new MockDelayedOperation(100000L, Some(lock), Some(lock))
}
verifyDelayedOperationLock(newMockOperation, mismatchedLocks = false)
verifyDelayedOperationLock(new MockDelayedOperation(100000L, None, Some(new ReentrantLock)),
mismatchedLocks = true)
}
def verifyDelayedOperationLock(mockDelayedOperation: => MockDelayedOperation, mismatchedLocks: Boolean): Unit = {
val key = "key"
executorService = Executors.newSingleThreadExecutor
