diff --git a/app/src/main/java/org/thoughtcrime/securesms/service/webrtc/SignalCallManager.java b/app/src/main/java/org/thoughtcrime/securesms/service/webrtc/SignalCallManager.java index 5ec02b80ca..40f01462b2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/service/webrtc/SignalCallManager.java +++ b/app/src/main/java/org/thoughtcrime/securesms/service/webrtc/SignalCallManager.java @@ -14,6 +14,9 @@ import androidx.annotation.Nullable; import com.annimon.stream.Stream; import org.greenrobot.eventbus.EventBus; +import org.signal.core.models.ServiceId.ACI; +import org.signal.core.util.Util; +import org.signal.core.util.concurrent.KeyedSerialMonoLifoExecutor; import org.signal.core.util.concurrent.SignalExecutors; import org.signal.core.util.logging.Log; import org.signal.libsignal.zkgroup.GenericServerPublicParams; @@ -70,7 +73,6 @@ import org.thoughtcrime.securesms.service.webrtc.state.WebRtcServiceState; import org.thoughtcrime.securesms.util.AppForegroundObserver; import org.thoughtcrime.securesms.util.RecipientAccessList; import org.thoughtcrime.securesms.util.TextSecurePreferences; -import org.signal.core.util.Util; import org.thoughtcrime.securesms.util.rx.RxStore; import org.thoughtcrime.securesms.webrtc.CallNotificationBuilder; import org.thoughtcrime.securesms.webrtc.audio.SignalAudioManager; @@ -87,7 +89,6 @@ import org.whispersystems.signalservice.api.messages.calls.OpaqueMessage; import org.whispersystems.signalservice.api.messages.calls.SignalServiceCallMessage; import org.whispersystems.signalservice.api.messages.calls.TurnServerInfo; import org.whispersystems.signalservice.api.messages.multidevice.SignalServiceSyncMessage; -import org.signal.core.models.ServiceId.ACI; import org.whispersystems.signalservice.api.push.exceptions.ProofRequiredException; import org.whispersystems.signalservice.api.push.exceptions.UnregisteredUserException; import org.whispersystems.signalservice.internal.push.SyncMessage; @@ -105,7 +106,6 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.UUID; -import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.Consumer; @@ -113,9 +113,9 @@ import java.util.stream.Collectors; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.schedulers.Schedulers; +import kotlin.Pair; import kotlin.jvm.functions.Function1; import kotlin.text.Charsets; -import kotlin.Pair; import static org.thoughtcrime.securesms.events.WebRtcViewModel.GroupCallState.IDLE; import static org.thoughtcrime.securesms.events.WebRtcViewModel.State.CALL_INCOMING; @@ -136,10 +136,10 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. @Nullable private final CallManager callManager; - private final Context context; - private final ExecutorService serviceExecutor; - private final Executor networkExecutor; - private final LockManager lockManager; + private final Context context; + private final ExecutorService serviceExecutor; + private final KeyedSerialMonoLifoExecutor keyedExecutor; + private final LockManager lockManager; private WebRtcServiceState serviceState; private RxStore ephemeralStateStore; @@ -151,7 +151,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. this.context = application.getApplicationContext(); this.lockManager = new LockManager(this.context); this.serviceExecutor = Executors.newSingleThreadExecutor(); - this.networkExecutor = Executors.newSingleThreadExecutor(); + this.keyedExecutor = new KeyedSerialMonoLifoExecutor(SignalExecutors.BOUNDED_IO); this.ephemeralStateStore = new RxStore<>(new WebRtcEphemeralState(), Schedulers.from(serviceExecutor)); this.linkPeekInfoStore = new RxStore<>(new HashMap<>(), Schedulers.from(serviceExecutor)); @@ -411,7 +411,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. return; } - networkExecutor.execute(() -> { + keyedExecutor.execute(id.toString(), () -> { try { Recipient callLinkRecipient = Recipient.resolved(id); CallLinkRoomId callLinkRoomId = callLinkRecipient.requireCallLinkRoomId(); @@ -475,7 +475,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. return; } - networkExecutor.execute(() -> { + keyedExecutor.execute(id.toString(), () -> { try { Recipient group = Recipient.resolved(id); GroupId.V2 groupId = group.requireGroupId().requireV2(); @@ -515,7 +515,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. return; } - networkExecutor.execute(() -> { + keyedExecutor.execute("Call::" + info.getRecipientId(), () -> { try { Recipient group = Recipient.resolved(info.getRecipientId()); GroupId.V2 groupId = group.requireGroupId().requireV2(); @@ -538,7 +538,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. } void requestGroupMembershipToken(@NonNull GroupId.V2 groupId, int groupCallHashCode) { - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { try { ExternalGroupCredential credential = GroupManager.getExternalGroupCredential(context, groupId); process((s, p) -> p.handleGroupMembershipProofResponse(s, groupCallHashCode, credential.token.getBytes(Charsets.UTF_8))); @@ -871,7 +871,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. OpaqueMessage opaqueMessage = new OpaqueMessage(message, getUrgencyFromCallUrgency(urgency)); SignalServiceCallMessage callMessage = SignalServiceCallMessage.forOpaque(opaqueMessage, null); - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { Recipient recipient = Recipient.resolved(RecipientId.from(ACI.from(aciUuid))); if (recipient.isBlocked()) { return; @@ -901,7 +901,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. public void onSendCallMessageToGroup(@NonNull byte[] groupIdBytes, @NonNull byte[] message, @NonNull CallManager.CallMessageUrgency urgency, @NonNull List overrideRecipients) { Log.i(TAG, "onSendCallMessageToGroup():"); - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { try { GroupId groupId = GroupId.v2(new GroupIdentifier(groupIdBytes)); List recipients = SignalDatabase.groups().getGroupMembers(groupId, GroupTable.MemberSet.FULL_MEMBERS_EXCLUDING_SELF); @@ -957,7 +957,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. } Log.i(TAG, "onSendHttpRequest(): request_id: " + requestId); - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { List> headerPairs; if (headers != null) { headerPairs = Stream.of(headers) @@ -1130,7 +1130,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. } public void retrieveTurnServers(@NonNull RemotePeer remotePeer) { - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { try { List cachedServers = TurnServerCache.getCachedServers(); if (cachedServers != null) { @@ -1279,7 +1279,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. public void sendCallMessage(@NonNull final RemotePeer remotePeer, @NonNull final SignalServiceCallMessage callMessage) { - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { Recipient recipient = Recipient.resolved(remotePeer.getId()); if (recipient.isBlocked()) { return; @@ -1319,7 +1319,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. .updateOneToOneCall(remotePeer.getCallId().longValue(), CallTable.Event.ACCEPTED); if (SignalStore.account().isMultiDevice()) { - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { try { SyncMessage.CallEvent callEvent = CallEventSyncMessageUtil.createAcceptedSyncMessage(remotePeer, System.currentTimeMillis(), isOutgoing, isVideoCall); AppDependencies.getSignalServiceMessageSender().sendSyncMessage(SignalServiceSyncMessage.forCallEvent(callEvent)); @@ -1336,7 +1336,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. .updateOneToOneCall(remotePeer.getCallId().longValue(), CallTable.Event.NOT_ACCEPTED); if (SignalStore.account().isMultiDevice()) { - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { try { SyncMessage.CallEvent callEvent = CallEventSyncMessageUtil.createNotAcceptedSyncMessage(remotePeer, System.currentTimeMillis(), isOutgoing, isVideoCall); AppDependencies.getSignalServiceMessageSender().sendSyncMessage(SignalServiceSyncMessage.forCallEvent(callEvent)); @@ -1349,7 +1349,7 @@ public final class SignalCallManager implements CallManager.Observer, GroupCall. public void sendGroupCallNotAcceptedCallEventSyncMessage(@NonNull RemotePeer remotePeer, boolean isOutgoing) { if (SignalStore.account().isMultiDevice()) { - networkExecutor.execute(() -> { + SignalExecutors.BOUNDED_IO.execute(() -> { try { SyncMessage.CallEvent callEvent = CallEventSyncMessageUtil.createNotAcceptedSyncMessage(remotePeer, System.currentTimeMillis(), isOutgoing, true); AppDependencies.getSignalServiceMessageSender().sendSyncMessage(SignalServiceSyncMessage.forCallEvent(callEvent)); diff --git a/core/util/src/main/java/org/signal/core/util/concurrent/KeyedSerialMonoLifoExecutor.kt b/core/util/src/main/java/org/signal/core/util/concurrent/KeyedSerialMonoLifoExecutor.kt new file mode 100644 index 0000000000..0f78879bfd --- /dev/null +++ b/core/util/src/main/java/org/signal/core/util/concurrent/KeyedSerialMonoLifoExecutor.kt @@ -0,0 +1,64 @@ +package org.signal.core.util.concurrent + +import java.util.concurrent.Executor + +/** + * Like [org.thoughtcrime.securesms.util.concurrent.SerialMonoLifoExecutor], but manages independent queues keyed by a string. + * + * Each key gets its own active/next pair, so tasks with different keys can run concurrently on the + * backing executor. Within a given key, only two tasks exist at a time: the currently running one + * and the most recently enqueued one. Any previously-pending task for that key is replaced. + * + * Idle keys are cleaned up automatically when their work completes. + */ +class KeyedSerialMonoLifoExecutor(private val executor: Executor) { + + private val entries = mutableMapOf() + + @Synchronized + fun execute(key: String, command: Runnable) { + enqueue(key, command) + } + + /** + * @return True if a pending task for this key was replaced, otherwise false. + */ + @Synchronized + fun enqueue(key: String, command: Runnable): Boolean { + val entry = entries.getOrPut(key) { TaskEntry() } + val performedReplace = entry.next != null + + entry.next = Runnable { + try { + command.run() + } finally { + scheduleNext(key) + } + } + + if (entry.active == null) { + scheduleNext(key) + } + + return performedReplace + } + + @Synchronized + private fun scheduleNext(key: String) { + val entry = entries[key] ?: return + + entry.active = entry.next + entry.next = null + + if (entry.active != null) { + executor.execute(entry.active) + } else { + entries.remove(key) + } + } + + private class TaskEntry { + var active: Runnable? = null + var next: Runnable? = null + } +} diff --git a/core/util/src/test/java/org/signal/core/util/concurrent/KeyedSerialMonoLifoExecutorTest.kt b/core/util/src/test/java/org/signal/core/util/concurrent/KeyedSerialMonoLifoExecutorTest.kt new file mode 100644 index 0000000000..1183e37235 --- /dev/null +++ b/core/util/src/test/java/org/signal/core/util/concurrent/KeyedSerialMonoLifoExecutorTest.kt @@ -0,0 +1,180 @@ +package org.signal.core.util.concurrent + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test +import java.util.concurrent.Executor + +class KeyedSerialMonoLifoExecutorTest { + + @Test + fun `first task runs immediately`() { + val executor = TestExecutor() + val subject = KeyedSerialMonoLifoExecutor(executor) + val task = TestRunnable() + + subject.execute("a", task) + + assertEquals(1, executor.pending()) + executor.runNext() + assertTrue(task.didRun) + } + + @Test + fun `second task is held until first completes`() { + val executor = TestExecutor() + val subject = KeyedSerialMonoLifoExecutor(executor) + val first = TestRunnable() + val second = TestRunnable() + + subject.execute("a", first) + subject.execute("a", second) + + assertEquals(1, executor.pending()) + executor.runNext() + assertTrue(first.didRun) + assertFalse(second.didRun) + + assertEquals(1, executor.pending()) + executor.runNext() + assertTrue(second.didRun) + } + + @Test + fun `only the latest pending task is kept`() { + val executor = TestExecutor() + val subject = KeyedSerialMonoLifoExecutor(executor) + val first = TestRunnable() + val replaced1 = TestRunnable() + val replaced2 = TestRunnable() + val latest = TestRunnable() + + subject.execute("a", first) + subject.execute("a", replaced1) + subject.execute("a", replaced2) + subject.execute("a", latest) + + executor.runNext() + assertTrue(first.didRun) + + executor.runNext() + assertTrue(latest.didRun) + assertFalse(replaced1.didRun) + assertFalse(replaced2.didRun) + + assertEquals(0, executor.pending()) + } + + @Test + fun `enqueue returns true when replacing a pending task`() { + val executor = TestExecutor() + val subject = KeyedSerialMonoLifoExecutor(executor) + + val firstReplace = subject.enqueue("a", TestRunnable()) + assertFalse(firstReplace) + + val secondReplace = subject.enqueue("a", TestRunnable()) + assertFalse(secondReplace) + + val thirdReplace = subject.enqueue("a", TestRunnable()) + assertTrue(thirdReplace) + } + + @Test + fun `different keys dedupe independently`() { + val executor = TestExecutor() + val subject = KeyedSerialMonoLifoExecutor(executor) + val a1 = TestRunnable() + val a2replaced = TestRunnable() + val a3 = TestRunnable() + val b1 = TestRunnable() + val b2 = TestRunnable() + + subject.execute("a", a1) + subject.execute("a", a2replaced) + subject.execute("a", a3) + subject.execute("b", b1) + subject.execute("b", b2) + + // a1 and b1 should both be dispatched + assertEquals(2, executor.pending()) + + executor.runNext() // a1 + assertTrue(a1.didRun) + + executor.runNext() // b1 + assertTrue(b1.didRun) + + executor.runNext() // a3 (a2replaced was dropped) + assertTrue(a3.didRun) + assertFalse(a2replaced.didRun) + + executor.runNext() // b2 + assertTrue(b2.didRun) + + assertEquals(0, executor.pending()) + } + + @Test + fun `idle keys are cleaned up`() { + val executor = TestExecutor() + val subject = KeyedSerialMonoLifoExecutor(executor) + + // Iteration 1: fill the queue (active + pending), drain it fully + val a1 = TestRunnable() + val a2 = TestRunnable() + subject.execute("a", a1) + subject.execute("a", a2) + executor.runNext() + executor.runNext() + assertTrue(a1.didRun) + assertTrue(a2.didRun) + assertEquals(0, executor.pending()) + + // Iteration 2: reuse the same key — should work with no stale state + val b1 = TestRunnable() + val b2 = TestRunnable() + subject.execute("a", b1) + subject.execute("a", b2) + executor.runNext() + executor.runNext() + assertTrue(b1.didRun) + assertTrue(b2.didRun) + assertEquals(0, executor.pending()) + + // Iteration 3: once more to confirm repeated cleanup + val c1 = TestRunnable() + val c2 = TestRunnable() + subject.execute("a", c1) + subject.execute("a", c2) + executor.runNext() + executor.runNext() + assertTrue(c1.didRun) + assertTrue(c2.didRun) + assertEquals(0, executor.pending()) + } + + private class TestExecutor : Executor { + private val tasks = ArrayDeque() + + override fun execute(command: Runnable) { + tasks.addLast(command) + } + + fun pending(): Int = tasks.size + + fun runNext() { + tasks.removeFirst().run() + } + } + + private class TestRunnable : Runnable { + var didRun = false + private set + + override fun run() { + didRun = true + } + } +}