Queue LibSignalChatConnection::sendRequest() in CONNECTING state.

This commit is contained in:
andrew-signal
2025-03-04 13:42:22 -05:00
committed by Greyson Parrelli
parent 38292f26b1
commit 165322afc1
5 changed files with 111 additions and 17 deletions

View File

@@ -10,6 +10,10 @@ import org.signal.libsignal.internal.CompletableFuture
/**
* A Kotlin friendly adapter for [org.signal.libsignal.internal.CompletableFuture.whenComplete]
* taking two callbacks ([onSuccess] and [onFailure]) instead of a [java.util.function.BiConsumer].
*
* Note that for libsignal's implementation of CompletableFuture, whenComplete will complete handlers in
* the order they are enqueued. This is a stronger guarantee than is given by the standard Java specification
* and is actively used by clients (e.g. LibSignalChatConnection) to reduce boilerplate in handling race conditions.
*/
fun <T> CompletableFuture<T>.whenComplete(
onSuccess: ((T?) -> Unit),

View File

@@ -26,6 +26,7 @@ import org.whispersystems.signalservice.api.websocket.HealthMonitor
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import org.whispersystems.signalservice.internal.util.whenComplete
import java.io.IOException
import java.net.SocketException
import java.time.Instant
import java.util.Optional
import java.util.concurrent.ConcurrentHashMap
@@ -38,7 +39,6 @@ import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.time.Duration.Companion.seconds
import org.signal.libsignal.net.ChatConnection.Request as LibSignalRequest
import org.signal.libsignal.net.ChatConnection.Response as LibSignalResponse
/**
* Implements the WebSocketConnection interface via libsignal-net
@@ -120,15 +120,6 @@ class LibSignalChatConnection(
timeout.toInt()
)
}
private fun LibSignalResponse.toWebsocketResponse(isUnidentified: Boolean): WebsocketResponse {
return WebsocketResponse(
this.status,
this.body.decodeToString(),
this.headers,
isUnidentified
)
}
}
override val name = "[$name:${System.identityHashCode(this)}]"
@@ -269,14 +260,40 @@ class LibSignalChatConnection(
return Single.error(IOException("$name is closed!"))
}
// This avoids a crash loop when we try to send queued messages on app open before the connection
// is fully established.
// TODO [andrew]: Figure out if this is the right long term behavior.
val single = SingleSubject.create<WebsocketResponse>()
if (state.value == WebSocketConnectionState.CONNECTING) {
return Single.error(IOException("$name is still connecting!"))
// In OkHttpWebSocketConnection, if a client calls sendRequest while we are still
// connecting to the Chat service, we queue the request to be sent after the
// the connection is established.
// We carry forward that behavior here, except we have to use future chaining
// rather than directly writing to the connection for it to buffer for us,
// because libsignal-net does not expose a connection handle until the connection
// is established.
Log.i(TAG, "[sendRequest] Enqueuing request send for after connection")
// We are in the CONNECTING state, so our invariant says that chatConnectionFuture should
// be set, so we should not have to worry about nullability here.
chatConnectionFuture!!.whenComplete(
onSuccess = {
// We depend on the libsignal's CompletableFuture's synchronization guarantee to
// keep this implementation simple. If another CompletableFuture implementation is
// used, we'll need to add some logic here to be ensure this completion handler
// fires after the one enqueued in connect().
sendRequest(request)
.subscribe(
{ response -> single.onSuccess(response) },
{ error -> single.onError(error) }
)
},
onFailure = {
// This matches the behavior of OkHttpWebSocketConnection when the connection fails
// before the buffered request can be sent.
single.onError(SocketException("Closed unexpectedly"))
}
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
}
val single = SingleSubject.create<WebsocketResponse>()
val internalRequest = request.toLibSignalRequest()
chatConnection!!.send(internalRequest)
.whenComplete(
@@ -296,7 +313,10 @@ class LibSignalChatConnection(
},
onFailure = { throwable ->
Log.w(TAG, "$name [sendRequest] Failure:", throwable)
single.onError(throwable)
// The clients of WebSocketConnection are often sensitive to the exact type of exception returned.
// This is the exception that OkHttpWebSocketConnection throws in the closest scenario to this, when
// the connection fails before the request completes.
single.onError(SocketException("Failed to get response for request"))
}
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())

View File

@@ -0,0 +1,17 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.internal.websocket
import org.signal.libsignal.net.ChatConnection.Response
fun Response.toWebsocketResponse(isUnidentified: Boolean): WebsocketResponse {
return WebsocketResponse(
this.status,
this.body.decodeToString(),
this.headers,
isUnidentified
)
}

View File

@@ -7,6 +7,7 @@ import org.whispersystems.signalservice.api.util.Preconditions;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class WebsocketResponse {
private final int status;
@@ -41,6 +42,19 @@ public class WebsocketResponse {
return unidentified;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final WebsocketResponse that = (WebsocketResponse) o;
return status == that.status && unidentified == that.unidentified && Objects.equals(body, that.body) && Objects.equals(headers, that.headers);
}
@Override
public int hashCode() {
return Objects.hash(status, body, headers, unidentified);
}
private static Map<String, String> parseHeaders(List<String> rawHeaders) {
Map<String, String> headers = new HashMap<>(rawHeaders.size());

View File

@@ -417,15 +417,54 @@ class LibSignalChatConnectionTest {
@Test
fun regressionTestSendWhileConnecting() {
var connectionCompletionFuture: CompletableFuture<UnauthenticatedChatConnection>? = null
every { network.connectUnauthChat(any()) } answers {
chatListener = firstArg()
delay {
// We do not complete the future, so we stay in the CONNECTING state forever.
connectionCompletionFuture = it
}
}
sendLatch = CountDownLatch(1)
connection.connect()
connection.sendRequest(WebSocketRequestMessage("GET", "/fake-path"))
val sendSingle = connection.sendRequest(WebSocketRequestMessage("GET", "/fake-path"))
val sendObserver = sendSingle.test()
assertEquals(1, sendLatch!!.count)
sendObserver.assertNotComplete()
connectionCompletionFuture!!.complete(chatConnection)
sendLatch!!.await(100, TimeUnit.MILLISECONDS)
sendObserver.awaitDone(100, TimeUnit.MILLISECONDS)
sendObserver.assertValues(RESPONSE_SUCCESS.toWebsocketResponse(true))
}
@Test
fun testSendFailsWhenConnectionFails() {
var connectionCompletionFuture: CompletableFuture<UnauthenticatedChatConnection>? = null
every { network.connectUnauthChat(any()) } answers {
chatListener = firstArg()
delay {
connectionCompletionFuture = it
}
}
sendLatch = CountDownLatch(1)
connection.connect()
val sendSingle = connection.sendRequest(WebSocketRequestMessage("GET", "/fake-path"))
val sendObserver = sendSingle.test()
assertEquals(1, sendLatch!!.count)
sendObserver.assertNotComplete()
connectionCompletionFuture!!.completeExceptionally(ChatServiceException(""))
sendObserver.awaitDone(100, TimeUnit.MILLISECONDS)
assertEquals(1, sendLatch!!.count)
sendObserver.assertFailure(IOException().javaClass)
}
private fun <T> delay(action: ((CompletableFuture<T>) -> Unit)): CompletableFuture<T> {