Skip to content

Commit

Permalink
Merge pull request #774 from arkivanov/refCount-race-fix
Browse files Browse the repository at this point in the history
Fixed a race condition in refCount
  • Loading branch information
CherryPerry authored Feb 24, 2024
2 parents a7a0ccd + 37cf3e5 commit e02f230
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package com.badoo.reaktive.observable

import com.badoo.reaktive.disposable.CompositeDisposable
import com.badoo.reaktive.disposable.Disposable
import com.badoo.reaktive.disposable.SerialDisposable
import com.badoo.reaktive.disposable.plusAssign
import com.badoo.reaktive.utils.atomic.AtomicInt
import com.badoo.reaktive.utils.atomic.AtomicReference
import com.badoo.reaktive.utils.atomic.getAndChange
import com.badoo.reaktive.utils.lock.Lock

/**
* Returns an [Observable] that connects to this [ConnectableObservable] when the number
Expand All @@ -16,23 +15,15 @@ import com.badoo.reaktive.utils.atomic.getAndChange
fun <T> ConnectableObservable<T>.refCount(subscriberCount: Int = 1): Observable<T> {
require(subscriberCount > 0)

val subscribeCount = AtomicInt()
val disposable = AtomicReference<Disposable?>(null)
var subscribeCount = 0
val lock = Lock()
val connectionDisposable = SerialDisposable()

return observable { emitter ->
val disposables = CompositeDisposable()
emitter.setDisposable(disposables)

disposables +=
Disposable {
if (subscribeCount.addAndGet(-1) == 0) {
disposable
.getAndChange { null }
?.dispose()
}
}

val shouldConnect = subscribeCount.addAndGet(1) == subscriberCount
val shouldConnect = lock.synchronized { ++subscribeCount == subscriberCount }

this@refCount.subscribe(
object : ObservableObserver<T>, ObservableCallbacks<T> by emitter {
Expand All @@ -43,9 +34,16 @@ fun <T> ConnectableObservable<T>.refCount(subscriberCount: Int = 1): Observable<
)

if (shouldConnect) {
this@refCount.connect {
disposable.value = it
}
this@refCount.connect(connectionDisposable::set)
}

disposables +=
Disposable {
lock.synchronized {
if (--subscribeCount == 0) {
connectionDisposable.set(null)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,40 @@ class RefCountTest {
assertTrue(disposable.isDisposed)
}

@Test
fun connects_to_upstream_WHEN_subscriberCount_is_1_and_subscribed_and_disposed_in_onSubscribe() {
var isConnected = false
val upstream = testUpstream(connect = { isConnected = true })
val refCount = upstream.refCount(subscriberCount = 1)

refCount.subscribe(
object : DefaultObservableObserver<Int?> {
override fun onSubscribe(disposable: Disposable) {
disposable.dispose()
}
}
)

assertTrue(isConnected)
}

@Test
fun disconnects_from_upstream_WHEN_subscriberCount_is_1_and_subscribed_and_disposed_in_onSubscribe() {
val disposable = Disposable()
val upstream = testUpstream(connect = { onConnect -> onConnect?.invoke(disposable) })
val refCount = upstream.refCount(subscriberCount = 1)

refCount.subscribe(
object : DefaultObservableObserver<Int?> {
override fun onSubscribe(disposable: Disposable) {
disposable.dispose()
}
}
)

assertTrue(disposable.isDisposed)
}

@Test
fun disconnects_from_upstream_WHEN_subscriberCount_is_2_and_all_subscribers_unsubscribed() {
val disposable = Disposable()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package com.badoo.reaktive.observable

import com.badoo.reaktive.disposable.Disposable
import com.badoo.reaktive.test.doInBackground
import com.badoo.reaktive.test.observable.test
import com.badoo.reaktive.utils.lock.ConditionLock
import com.badoo.reaktive.utils.lock.synchronized
import com.badoo.reaktive.utils.lock.waitFor
import com.badoo.reaktive.utils.lock.waitForOrFail
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.time.Duration.Companion.seconds

class RefCountThreadingTest {

@Test
fun does_not_connect_second_time_concurrently_while_disconnecting() {
val lock = ConditionLock()
var isDisconnecting = false
var isSecondTime = false
var isConnectedSecondTimeConcurrently = false

val disposable =
Disposable {
lock.synchronized {
isDisconnecting = true
isSecondTime = true
lock.signal()
lock.waitFor(timeout = 1.seconds) { false }
isDisconnecting = false
}
}

val upstream =
testUpstream(
connect = { onConnect ->
lock.synchronized {
if (!isSecondTime) {
onConnect?.invoke(disposable)
} else {
isConnectedSecondTimeConcurrently = isDisconnecting
}
}
}
)

val refCount = upstream.refCount(subscriberCount = 1)
val observer = refCount.test()
doInBackground { observer.dispose() }

lock.synchronized {
lock.waitForOrFail { !isSecondTime }
}

refCount.test()

assertFalse(isConnectedSecondTimeConcurrently)
}

private fun testUpstream(
connect: (onConnect: ((Disposable) -> Unit)?) -> Unit = {},
): ConnectableObservable<Int?> =
object : ConnectableObservable<Int?> {
override fun connect(onConnect: ((Disposable) -> Unit)?) {
connect.invoke(onConnect)
}

override fun subscribe(observer: ObservableObserver<Int?>) {
observer.onSubscribe(Disposable())
}
}
}

0 comments on commit e02f230

Please sign in to comment.