diff --git a/src/sync.jl b/src/sync.jl index d353efcd7..03dd3b3fb 100644 --- a/src/sync.jl +++ b/src/sync.jl @@ -1,27 +1,35 @@ module Sync + mutable struct LockStatus + nested :: Int + owner :: Union{Task, Nothing} + end + const mutex = ReentrantLock() - const sync_level = repeat([ 0 ], Threads.nthreads()) + const lock_status = LockStatus(0, nothing) - @inline is_locked() = sync_level[Threads.threadid()] > 0 + @inline is_locked() = lock_status.owner == Base.current_task() @inline function lock() - tid = Threads.threadid() - if sync_level[tid] == 0 - Base.lock(mutex) - end - sync_level[tid] += 1 + if is_locked() + lock_status.nested += 1 + else + Base.lock(mutex) + lock_status.nested = 1 + lock_status.owner = Base.current_task() + end end @inline function unlock() - tid = Threads.threadid() - sync_level[tid] -= 1 - if sync_level[tid] == 0 - Base.unlock(mutex) - end + @assert is_locked() + lock_status.nested -= 1 + if lock_status.nested == 0 + lock_status.owner = nothing + Base.unlock(mutex) + end end @inline function check_lock() - assert(is_locked()) + @assert is_locked() end end