From 165322afc17b48ef80f5e3a2973bc5a13bdbc898 Mon Sep 17 00:00:00 2001 From: andrew-signal Date: Tue, 4 Mar 2025 13:42:22 -0500 Subject: [PATCH] Queue LibSignalChatConnection::sendRequest() in CONNECTING state. --- .../util/CompletableFutureExtensions.kt | 4 ++ .../websocket/LibSignalChatConnection.kt | 52 +++++++++++++------ .../websocket/LibSignalResponseExtension.kt | 17 ++++++ .../internal/websocket/WebsocketResponse.java | 14 +++++ .../websocket/LibSignalChatConnectionTest.kt | 41 ++++++++++++++- 5 files changed, 111 insertions(+), 17 deletions(-) create mode 100644 libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalResponseExtension.kt diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/util/CompletableFutureExtensions.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/util/CompletableFutureExtensions.kt index d3bed62f12..aee321e81e 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/util/CompletableFutureExtensions.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/util/CompletableFutureExtensions.kt @@ -10,6 +10,10 @@ import org.signal.libsignal.internal.CompletableFuture /** * A Kotlin friendly adapter for [org.signal.libsignal.internal.CompletableFuture.whenComplete] * taking two callbacks ([onSuccess] and [onFailure]) instead of a [java.util.function.BiConsumer]. + * + * Note that for libsignal's implementation of CompletableFuture, whenComplete will complete handlers in + * the order they are enqueued. This is a stronger guarantee than is given by the standard Java specification + * and is actively used by clients (e.g. LibSignalChatConnection) to reduce boilerplate in handling race conditions. */ fun CompletableFuture.whenComplete( onSuccess: ((T?) -> Unit), diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index c6fc372da2..4f860f50f8 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -26,6 +26,7 @@ import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.internal.util.whenComplete import java.io.IOException +import java.net.SocketException import java.time.Instant import java.util.Optional import java.util.concurrent.ConcurrentHashMap @@ -38,7 +39,6 @@ import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.withLock import kotlin.time.Duration.Companion.seconds import org.signal.libsignal.net.ChatConnection.Request as LibSignalRequest -import org.signal.libsignal.net.ChatConnection.Response as LibSignalResponse /** * Implements the WebSocketConnection interface via libsignal-net @@ -120,15 +120,6 @@ class LibSignalChatConnection( timeout.toInt() ) } - - private fun LibSignalResponse.toWebsocketResponse(isUnidentified: Boolean): WebsocketResponse { - return WebsocketResponse( - this.status, - this.body.decodeToString(), - this.headers, - isUnidentified - ) - } } override val name = "[$name:${System.identityHashCode(this)}]" @@ -269,14 +260,40 @@ class LibSignalChatConnection( return Single.error(IOException("$name is closed!")) } - // This avoids a crash loop when we try to send queued messages on app open before the connection - // is fully established. - // TODO [andrew]: Figure out if this is the right long term behavior. + val single = SingleSubject.create() + if (state.value == WebSocketConnectionState.CONNECTING) { - return Single.error(IOException("$name is still connecting!")) + // In OkHttpWebSocketConnection, if a client calls sendRequest while we are still + // connecting to the Chat service, we queue the request to be sent after the + // the connection is established. + // We carry forward that behavior here, except we have to use future chaining + // rather than directly writing to the connection for it to buffer for us, + // because libsignal-net does not expose a connection handle until the connection + // is established. + Log.i(TAG, "[sendRequest] Enqueuing request send for after connection") + // We are in the CONNECTING state, so our invariant says that chatConnectionFuture should + // be set, so we should not have to worry about nullability here. + chatConnectionFuture!!.whenComplete( + onSuccess = { + // We depend on the libsignal's CompletableFuture's synchronization guarantee to + // keep this implementation simple. If another CompletableFuture implementation is + // used, we'll need to add some logic here to be ensure this completion handler + // fires after the one enqueued in connect(). + sendRequest(request) + .subscribe( + { response -> single.onSuccess(response) }, + { error -> single.onError(error) } + ) + }, + onFailure = { + // This matches the behavior of OkHttpWebSocketConnection when the connection fails + // before the buffered request can be sent. + single.onError(SocketException("Closed unexpectedly")) + } + ) + return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) } - val single = SingleSubject.create() val internalRequest = request.toLibSignalRequest() chatConnection!!.send(internalRequest) .whenComplete( @@ -296,7 +313,10 @@ class LibSignalChatConnection( }, onFailure = { throwable -> Log.w(TAG, "$name [sendRequest] Failure:", throwable) - single.onError(throwable) + // The clients of WebSocketConnection are often sensitive to the exact type of exception returned. + // This is the exception that OkHttpWebSocketConnection throws in the closest scenario to this, when + // the connection fails before the request completes. + single.onError(SocketException("Failed to get response for request")) } ) return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalResponseExtension.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalResponseExtension.kt new file mode 100644 index 0000000000..6979a84bf3 --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalResponseExtension.kt @@ -0,0 +1,17 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.internal.websocket + +import org.signal.libsignal.net.ChatConnection.Response + +fun Response.toWebsocketResponse(isUnidentified: Boolean): WebsocketResponse { + return WebsocketResponse( + this.status, + this.body.decodeToString(), + this.headers, + isUnidentified + ) +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java index 940aa0aba8..baef42b4bc 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java @@ -7,6 +7,7 @@ import org.whispersystems.signalservice.api.util.Preconditions; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; public class WebsocketResponse { private final int status; @@ -41,6 +42,19 @@ public class WebsocketResponse { return unidentified; } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final WebsocketResponse that = (WebsocketResponse) o; + return status == that.status && unidentified == that.unidentified && Objects.equals(body, that.body) && Objects.equals(headers, that.headers); + } + + @Override + public int hashCode() { + return Objects.hash(status, body, headers, unidentified); + } + private static Map parseHeaders(List rawHeaders) { Map headers = new HashMap<>(rawHeaders.size()); diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index a6c7c6cee1..d117cc0576 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -417,15 +417,54 @@ class LibSignalChatConnectionTest { @Test fun regressionTestSendWhileConnecting() { + var connectionCompletionFuture: CompletableFuture? = null every { network.connectUnauthChat(any()) } answers { chatListener = firstArg() delay { // We do not complete the future, so we stay in the CONNECTING state forever. + connectionCompletionFuture = it } } + sendLatch = CountDownLatch(1) connection.connect() - connection.sendRequest(WebSocketRequestMessage("GET", "/fake-path")) + + val sendSingle = connection.sendRequest(WebSocketRequestMessage("GET", "/fake-path")) + val sendObserver = sendSingle.test() + + assertEquals(1, sendLatch!!.count) + sendObserver.assertNotComplete() + + connectionCompletionFuture!!.complete(chatConnection) + + sendLatch!!.await(100, TimeUnit.MILLISECONDS) + sendObserver.awaitDone(100, TimeUnit.MILLISECONDS) + sendObserver.assertValues(RESPONSE_SUCCESS.toWebsocketResponse(true)) + } + + @Test + fun testSendFailsWhenConnectionFails() { + var connectionCompletionFuture: CompletableFuture? = null + every { network.connectUnauthChat(any()) } answers { + chatListener = firstArg() + delay { + connectionCompletionFuture = it + } + } + sendLatch = CountDownLatch(1) + + connection.connect() + val sendSingle = connection.sendRequest(WebSocketRequestMessage("GET", "/fake-path")) + val sendObserver = sendSingle.test() + + assertEquals(1, sendLatch!!.count) + sendObserver.assertNotComplete() + + connectionCompletionFuture!!.completeExceptionally(ChatServiceException("")) + + sendObserver.awaitDone(100, TimeUnit.MILLISECONDS) + assertEquals(1, sendLatch!!.count) + sendObserver.assertFailure(IOException().javaClass) } private fun delay(action: ((CompletableFuture) -> Unit)): CompletableFuture {