Kotlin 协程源码阅读笔记 —— Mutex
我们在 Java
/ Kotlin
编程时如果需要某段代码块同一时间只有一个线程能够执行时,通常是使用 synchronized
,但是协程中可不能使用 synchronized
,为什么呢?如果你了解过协程的工作方式就不会觉得奇怪(如果不了解协程工作方式的同学,可以看
以下代码为什么不能正常工作呢(其中 hello()
和 world()
方法是 suspend
方法):
suspend fun helloWorld(): String {
synchronized(this) {
val hello = hello()
val world = world()
return hello + world
}
}
因为协程调用 suspend
方法后就相当于调用了一个异步函数,而后续的恢复执行时就相当于异步函数的回调成功。这个时候你想想:synchronized
可以在调用异步函数的时候获取锁,然后在回调成功的时候释放前面获取的锁吗?当然不行。那么我们要在协程中如何写出安全的代码(这里感觉还不能用线程安全这个词,因为协程中的锁不是基于线程的,或者叫协程安全?反正感觉也怪怪的)?答案是使用 Kotlin
协程中提供的 Mutex
(它是一种不允许重入的互斥锁)。上面的代码改成以下代码就可以执行了:
val lock = Mutex()
suspend fun helloWorld(): String {
lock.lock()
val hello = hello()
val world = world()
lock.unlock()
return hello + world
}
emmm,发散一下,那么通过 Java
中的可重入锁 ReentrantLock
可以对协程代码加锁吗?答案也是不行的,因为 ReentrantLock
的锁的状态也是针对线程的,在某个线程中获取了锁,也一定要在那个线程释放锁。而协程挂起前执行的线程和恢复后所在的线程不一定是同一个线程,所以不能使用 ReentrantLock
。举一个例子:协程中的代码执行的线程都是由 ContinuationInterceptor
来决定的,通常也是通过线程池来管理,在挂起前执行的线程,在执行挂起后,这个线程也就闲置了,闲置后他还可以去做其他的工作,比如在别的协程中做另外的工作,而挂起的协程恢复后需要线程执行,但是挂起前的线程有别的事情要做处于忙碌状态,这样 ContinuationInterceptor
就需要指定一个新的线程去执行恢复后的任务。所以这样挂起前后恢复后的线程就不一致了。
emmm,继续发散,像 Java
中线程安全的集合类,比如 ConcurrentHashMap
可以使用吗?可以使用,因为他们的内部又没有 suspend
函数,当然可以使用。只要没有使用 synchronized
或者 ReentrantLock
来锁 suspend
的函数都是没有问题的。
前面废话说的有点多,我们进入主题:Mutex
是如何工作的。
源码阅读基于 Kotlin 协程 1.8.0-RC2
获取锁
首先通过以下方法创建一个 Mutex
对象:
@Suppress("FunctionName")
public fun Mutex(locked: Boolean = false): Mutex =
MutexImpl(locked)
它的实现类是 MutexImpl
,它继承于 SemaphoreImpl
,这里不要和 Java
中的 Semaphore
搞混了,他们完全不是一回事儿。
我们看看 MutexImpl
获取锁的方法 lock()
:
override suspend fun lock(owner: Any?) {
// 尝试获取锁
if (tryLock(owner)) return
// 获取锁失败,挂起当前协程,直到锁可用后再恢复
lockSuspend(owner)
}
这里需要解释一下这个 owner
参数,它主要是用来 debug
发现我们代码中锁使用的问题的。lock()
和 unlock()
中传入的 owner
必须是同一个实例,如果两次不一样就会抛出异常。如果你在代码中成对调用 lock()
和 unlock()
的,那就不会有问题,也没有必要传一个 owner
对象。
首先通过 tryLock()
方法尝试去获取锁,如果获取成功就返回;如果获取失败通过 lockSuspend()
方法挂起当前的协程,等到锁被释放后再唤醒当前协程。
tryLock()
override fun tryLock(owner: Any?): Boolean = when (tryLockImpl(owner)) {
TRY_LOCK_SUCCESS -> true
TRY_LOCK_FAILED -> false
TRY_LOCK_ALREADY_LOCKED_BY_OWNER -> error("This mutex is already locked by the specified owner: $owner")
else -> error("unexpected")
}
这里继续调用 tryLockImpl()
方法去获取锁,如果返回值为 TRY_LOCK_SUCCESS
表示获取锁成功;TRY_LOCK_FAILED
表示获取锁失败;TRY_LOCK_ALREADY_LOCKED_BY_OWNER
表示同一个 owner
重复调用 lock()
方法,抛出异常。
我们再来看看 tryLockImpl()
方法:
private fun tryLockImpl(owner: Any?): Int {
while (true) {
// 获取锁
if (tryAcquire()) {
// 获取锁成功
assert { this.owner.value === NO_OWNER }
// 设置新的 owner
this.owner.value = owner
return TRY_LOCK_SUCCESS
} else {
// 获取锁失败
// The semaphore permit acquisition has failed.
// However, we need to check that this mutex is not
// locked by our owner.
if (owner == null) return TRY_LOCK_FAILED
// 检查 owner 状态
when (holdsLockImpl(owner)) {
// This mutex is already locked by our owner.
HOLDS_LOCK_YES -> return TRY_LOCK_ALREADY_LOCKED_BY_OWNER
// This mutex is locked by another owner, `trylock(..)` must return `false`.
HOLDS_LOCK_ANOTHER_OWNER -> return TRY_LOCK_FAILED
// This mutex is no longer locked, restart the operation.
HOLDS_LOCK_UNLOCKED -> continue
}
}
}
}
private fun holdsLockImpl(owner: Any?): Int {
while (true) {
// Is this mutex locked?
if (!isLocked) return HOLDS_LOCK_UNLOCKED
val curOwner = this.owner.value
// Wait in a spin-loop until the owner is set
if (curOwner === NO_OWNER) continue // <-- ATTENTION, BLOCKING PART HERE
// Check the owner
return if (curOwner === owner) HOLDS_LOCK_YES else HOLDS_LOCK_ANOTHER_OWNER
}
}
通过 tryAcquire()
尝试获取锁,后面的逻辑又分为两部分,获取锁成功和获取锁失败。
- 获取锁成功
代码很简单,直接更新owner
然后返回TRY_LOCK_SUCCESS
表示获取锁成功。 - 获取锁失败
如果是owner
为空,直接返回TRY_LOCK_FAILED
表示获取锁失败;如果owner
不为空,那么会通过holdsLockImpl()
方法检查owner
状态,根据不同的返回值有不同的处理方式:
-
HOLDS_LOCK_YES
表示当前owner
已经获取过锁了,表示重复获取锁。 -
HOLDS_LOCK_ANOTHER_OWNER
表示其他的owner
在持有锁,返回TRY_LOCK_FAILED
表示获取锁失败。 -
HOLDS_LOCK_UNLOCKED
表示当前并没有锁,然后重试调用tryAcquire()
继续获取锁(这种情况正好tryAcquire()
获取失败后,然后别的地方又释放了锁)。
继续看看 SemaphoreImpl#tryAcquire()
方法的实现:
override fun tryAcquire(): Boolean {
while (true) {
// Get the current number of available permits.
val p = _availablePermits.value
// Is the number of available permits greater
// than the maximal one because of an incorrect
// `release()` call without a preceding `acquire()`?
// Change it to `permits` and start from the beginning.
if (p > permits) {
coerceAvailablePermitsAtMaximum()
continue
}
// Try to decrement the number of available
// permits if it is greater than zero.
if (p <= 0) return false
if (_availablePermits.compareAndSet(p, p - 1)) return true
}
}
这里要解释一下 _availablePermits
,只有当它大于 0 时才可以获取锁,默认的 permits
是 1,也就是同时只有一个 lock()
方法能够获取到锁,其他的 lock()
方法只能等获取到锁的地方释放后才能继续执行。
当 _availablePermits
大于 0 时,通过 CAS
的方式把 _availablePermits
中的值减 1,如果 CAS
操作失败就重试,成功就直接返回 true
。在 Mutex
中很多地方用到了 CAS
自旋的方式去修改值,如果不懂的同学可以去网上找找 CAS
的概念。
lockSuspend()
private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable<Unit> { cont ->
val contWithOwner = CancellableContinuationWithOwner(cont, owner)
acquire(contWithOwner)
}
这里通过 suspendCancellableCoroutineReusable()
方法获取到了协程的 Continuation
对象,然后使用 CancellableContinuationWithOwner
对象将原来的 Continuation
封装了一下,然后继续调用 acquire()
方法。
继续看看 SemaphoreImpl#acquire()
方法的实现:
protected fun acquire(waiter: CancellableContinuation<Unit>) = acquire(
waiter = waiter,
suspend = { cont -> addAcquireToQueue(cont as Waiter) },
onAcquired = { cont -> cont.resume(Unit, onCancellationRelease) }
)
private inline fun <W> acquire(waiter: W, suspend: (waiter: W) -> Boolean, onAcquired: (waiter: W) -> Unit) {
while (true) {
// Decrement the number of available permits at first.
// 先获取 _availablePermits 值,然后再减 1
val p = decPermits()
// Is the permit acquired?
if (p > 0) {
// 表示获取锁成功
onAcquired(waiter)
return
}
// Permit has not been acquired, try to suspend.
// 执行挂起操作。
if (suspend(waiter)) return
}
}
suspend
与 onAcquired
这两个函数对象分别表示挂起操作和获取锁成功的操作。suspend
中通过 addAcquireToQueue()
方法将当前 Continaution
对象添加到等待队列;而 onAcquired
就非常简单了直接通过 Continuation#resume()
方法恢复协程。
我们看看 addAcquireToQueue()
方法的实现:
private fun addAcquireToQueue(waiter: Waiter): Boolean {
val curTail = this.tail.value
// 获取 id
val enqIdx = enqIdx.getAndIncrement()
// 获取创建 Segment 的方法
val createNewSegment = ::createSegment
// 从链表尾部开始查找 Segment,如果没有查找到就创建一个新的 Semgent
val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail,
createNewSegment = createNewSegment).segment // cannot be closed
// 计算 Continuation 存储在 Segemnt 中的 index
val i = (enqIdx % SEGMENT_SIZE).toInt()
// the regular (fast) path -- if the cell is empty, try to install continuation
// 通过 CAS 的方式将 Continuation 添加到 Segment 中去
if (segment.cas(i, null, waiter)) { // installed continuation successfully
// 添加成功,注册 Continuation 被取消的监听,取消后会通知 Segment
waiter.invokeOnCancellation(segment, i)
return true
}
// ...
// CAS 操作失败,在 acquire() 方法中会进行重试。
return false // broken cell, need to retry on a different cell
}
private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)
和等待队列相关的代码就要复杂一丢丢了,先解释一下,我们的等待的 Continuation
是存放在 Segment
中的,每个 Segment
最多能够存放 SEGMENT_SIZE
(默认 16) 个 Continuation
,存放的方式是数组, Segment
存放满了,就会创建新的 Segment
,Segment
之间是链表的存储方式。
简单整理下上面方法的流程:
- 通过
enqIdx
用来计算入队的id
,它是一个原子类,依次递增的。 - 通过
id / SEGMENT_SIZE
的方式计算Segment
的id
,然后通过方法findSegmentAndMoveForward()
从Segment
链表尾开始查找一个可用的Segment
,如果没有可用的了,就通过createSegment()
方法创建一个新的。 - 通过
id % SEGEMNT_SIZE
的方式计算出Continuation
存放在Segment
中的数组的位置。 - 通过
CAS
的方法将Continuation
添加到Segment
中,如果添加成功会调用Continuaiton#invokeOnCancellation()
方法来监听协程取消时的消息;如果添加失败会触发acquire()
方法重试。
我们再简单看看 findSegmentAndMoveForward()
方法是如何查找和创建一个新的 Segment
的:
@Suppress("NOTHING_TO_INLINE")
internal inline fun <S : Segment<S>> AtomicRef<S>.findSegmentAndMoveForward(
id: Long,
startFrom: S,
noinline createNewSegment: (id: Long, prev: S) -> S
): SegmentOrClosed<S> {
while (true) {
val s = startFrom.findSegmentInternal(id, createNewSegment)
// 检查查询到的 Segment 状态
if (s.isClosed || moveForward(s.segment)) return s
}
}
internal fun <S : Segment<S>> S.findSegmentInternal(
id: Long,
createNewSegment: (id: Long, prev: S) -> S
): SegmentOrClosed<S> {
var cur: S = this
// 当前的 id 如果小于目标 id 或者当前已经 remove了 执行查找
while (cur.id < id || cur.isRemoved) {
val next = cur.nextOrIfClosed { return SegmentOrClosed(CLOSED) }
// 如果 next 为空就表示需要创建新的 Segment,反之继续进入循环判断 id
if (next != null) { // there is a next node -- move there
cur = next
continue
}
// 创建一个新的 Segemnt
val newTail = createNewSegment(cur.id + 1, cur)
// 将旧的 tail 的 next 指向新的 Segemnt
if (cur.trySetNext(newTail)) { // successfully added new node -- move there
if (cur.isRemoved) cur.remove()
cur = newTail
}
}
return SegmentOrClosed(cur)
}
@Suppress("NOTHING_TO_INLINE", "RedundantNullableReturnType") // Must be inline because it is an AtomicRef extension
internal inline fun <S : Segment<S>> AtomicRef<S>.moveForward(to: S): Boolean = loop { cur ->
if (cur.id >= to.id) return true
if (!to.tryIncPointers()) return false
if (compareAndSet(cur, to)) { // the segment is moved
if (cur.decPointers()) cur.remove()
return true
}
if (to.decPointers()) to.remove() // undo tryIncPointers
}
获取/创建 Segment
的代码我就不重点介绍了,我添加了一些注释,大家自己看看。
到这里我们知道了如果 lock()
获取锁失败,对应的 Continaution
就会被挂起,然后 Continaution
对象会被添加到 Segment
中。聪明的你应该也猜到了,获取到锁的协程调用 unlock()
方法后就会尝试恢复 Segment
中的一个 Continaution
那么对应的调用 lock()
的协程就可以恢复执行了,是的,确实是这样,我们继续看看后面是怎么释放锁的。
释放锁
override fun unlock(owner: Any?) {
while (true) {
// Is this mutex locked?
// 检查锁状态
check(isLocked) { "This mutex is not locked" }
// Read the owner, waiting until it is set in a spin-loop if required.
// 检查 owner 状态
val curOwner = this.owner.value
if (curOwner === NO_OWNER) continue // <-- ATTENTION, BLOCKING PART HERE
// Check the owner.
check(curOwner === owner || owner == null) { "This mutex is locked by $curOwner, but $owner is expected" }
// Try to clean the owner first. We need to use CAS here to synchronize with concurrent `unlock(..)`-s.
// 将 owner 状态修改成 NO_OWNER
if (!this.owner.compareAndSet(curOwner, NO_OWNER)) continue
// Release the semaphore permit at the end.
// 释放锁操作
release()
return
}
}
上面代码比较简单,检查锁状态;检查 owner
状态;将 owner
设置为 NO_OWNER
;如果设置成功调用 release()
方法执行释放操作。上面的一些 CAS
操作如果失败,然后会再重试。
我们再看看 SemaphoreImpl#release()
方法的实现:
override fun release() {
while (true) {
// 获取 _availablePermits 后,再把它的值加 1
val p = _availablePermits.getAndIncrement()
if (p >= permits) {
// Revert the number of available permits
// back to the correct one and fail with error.
coerceAvailablePermitsAtMaximum()
error("The number of released permits cannot be greater than $permits")
}
// 如果 p 大于等于 0,表示没有 Continuation 在等待锁,直接返回
if (p >= 0) return
// 尝试从等待队列中恢复一个 Continuation
if (tryResumeNextFromQueue()) return
}
}
简单解释一下上面代码:
- 获取
_availablePermits
的值,并把它的值加 1. - 如果上次
_availablePermits
大于等于 0 就表示没有Continuation
在等待锁,反之就是有等待锁. - 如果有
Continuation
在等待锁,通过tryResumeNextFromQueue()
方法尝试从等待队列中获取等待最久的一个Continuation
来获取锁,并恢复它。
我们继续看看 tryResumeNextFromQueue()
方法的实现:
private fun tryResumeNextFromQueue(): Boolean {
val curHead = this.head.value
// 从 deqIdx 中获取基础 id
val deqIdx = deqIdx.getAndIncrement()
// 计算 Segment 的 id
val id = deqIdx / SEGMENT_SIZE
val createNewSegment = ::createSegment
// 和插入队列一样,先去查找一个 Segment,这里不同的是从 head 开始查找
val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
createNewSegment = createNewSegment).segment // cannot be closed
segment.cleanPrev()
// 判断查找到的 Segemnt 的 id 和目标的 id 是否对应,如果不对应就继续查找
if (segment.id > id) return false
// 获取对应的 Continuation 在 Segemnt 中对应的位置
val i = (deqIdx % SEGMENT_SIZE).toInt()
// 获取到对应的 Continuation
val cellState = segment.getAndSet(i, PERMIT) // set PERMIT and retrieve the prev cell state
when {
cellState === null -> {
// Acquire has not touched this cell yet, wait until it comes for a bounded time
// The cell state can only transition from PERMIT to TAKEN by addAcquireToQueue
repeat(MAX_SPIN_CYCLES) {
if (segment.get(i) === TAKEN) return true
}
// Try to break the slot in order not to wait
return !segment.cas(i, PERMIT, BROKEN)
}
cellState === CANCELLED -> return false // the acquirer has already been cancelled
// 恢复对应的 Continuation
else -> return cellState.tryResumeAcquire()
}
}
和插入队列时的代码类似,通过 deqIdx
来生成对应的目标的 Segment
的 id
和 Continuation
在 Segment
中的 index
(这里注意插入的时候是用的 enqIdx
,他们都是每次获取都会加 1)。查询 Semgent
同样是使用的 findSegmentAndMoveForward()
方法,不同的是出队列是从 head
开始查询。查询到对应的 Segment
后,会通过它的 getAndSet()
来获取一个 Continuation
(正常情况下是一个 Continuation
),然后调用它的 tryResumeAcquire()
方法,我们再来看看它的实现:
private fun Any.tryResumeAcquire(): Boolean = when(this) {
is CancellableContinuation<*> -> {
this as CancellableContinuation<Unit>
// 尝试 Resume
val token = tryResume(Unit, null, onCancellationRelease)
if (token != null) {
// 可以 Resume,通过 completeResume() 方法确认执行下发 Resume
completeResume(token)
true
} else false
}
is SelectInstance<*> -> {
trySelect(this@SemaphoreImpl, Unit)
}
else -> error("unexpected: $this")
}
我们的 Continuation
就是一个 CancellableContinuation
,首先会通过 tryResume()
方法去尝试恢复协程,如果返回的 token
不为空,就表示当前的协程可以恢复,然后通过 completeResume()
方法确认执行协程恢复。
我们看看 CancellableContinuationWithOwner#tryResume()
方法的实现:
override fun tryResume(value: Unit, idempotent: Any?, onCancellation: ((cause: Throwable) -> Unit)?): Any{
// 校验 owner 状态是 NO_OWNER
assert { this@MutexImpl.owner.value === NO_OWNER }
// 执行被代理的 Continaution 的 tryResume 方法
val token = cont.tryResume(value, idempotent) {
// 这个 Lambda 会在 Dispatcher 把任务取消时才执行,也就是表示 resume 失败了。
assert { this@MutexImpl.owner.value.let { it === NO_OWNER ||it === owner } }
this@MutexImpl.owner.value = owner
// 重新解锁
unlock(owner)
}
// token 不为空表示 tryResume 成功
if (token != null) {
assert { this@MutexImpl.owner.value === NO_OWNER }
// 将 owner 修改为当前 Continaution 的 owner.
this@MutexImpl.owner.value = owner
}
return token
}
上面代码很简单了,就是调用被代理 Continaution
的 tryResume()
方法,如果返回值不为空就表示协程恢复成功(如果 Dispatcher
将该 resume
任务取消就会解锁),协程恢复成功就会将它对应的 owner
添加到 MutexImpl
中去,表示由该 owner
占有当前锁。