Refactor LibSignalChatConnection to use an explicit queue for sendRequest handling while CONNECTING.

This commit is contained in:
andrew-signal
2025-06-05 19:46:11 -04:00
committed by Cody Henthorne
parent 89944d778b
commit 7bd52e661d
4 changed files with 249 additions and 132 deletions

View File

@@ -1005,7 +1005,7 @@ object RemoteConfig {
@JvmStatic
@get:JvmName("libSignalWebSocketEnabled")
val libSignalWebSocketEnabled: Boolean by remoteValue(
key = "android.libsignalWebSocketEnabled.4",
key = "android.libsignalWebSocketEnabled.5",
hotSwappable = false
) { value ->
value.asBoolean(false) || Environment.IS_NIGHTLY

View File

@@ -7,7 +7,6 @@ public enum WebSocketConnectionState {
DISCONNECTED,
CONNECTING,
CONNECTED,
RECONNECTING,
DISCONNECTING,
AUTHENTICATION_FAILED,
REMOTE_DEPRECATED,

View File

@@ -81,8 +81,14 @@ class LibSignalChatConnection(
private val nextIncomingMessageInternalPseudoId = AtomicLong(1)
val ackSenderForInternalPseudoId = ConcurrentHashMap<Long, ChatConnectionListener.ServerMessageAck>()
// CHAT_SERVICE_LOCK: Protects state, stateChangedOrMessageReceivedCondition, chatConnection, and
// chatConnectionFuture
/**
 * A request that sendRequest() accepted while the connection was still CONNECTING.
 *
 * Instances are held in requestsAwaitingConnection until the connection attempt resolves: on
 * success each one is handed to sendRequestInternal(), on failure its [single] is completed
 * with an error.
 */
private data class RequestAwaitingConnection(
// The request to transmit once the connection is established.
val request: WebSocketRequestMessage,
// Timeout, in seconds, applied when the request is converted to a libsignal request.
val timeoutSeconds: Long,
// Subject through which the eventual response (or error) is delivered.
val single: SingleSubject<WebsocketResponse>
)
// CHAT_SERVICE_LOCK: Protects state, stateChangedOrMessageReceivedCondition, chatConnection,
// chatConnectionFuture, and requestsAwaitingConnection.
// stateChangedOrMessageReceivedCondition: derived from CHAT_SERVICE_LOCK, used by readRequest(),
// exists to emulate idiosyncratic behavior of OkHttpWebSocketConnection for readRequest()
// chatConnection: Set only when state == CONNECTED
@@ -92,6 +98,10 @@ class LibSignalChatConnection(
private var chatConnection: ChatConnection? = null
private var chatConnectionFuture: CompletableFuture<out ChatConnection>? = null
// requestsAwaitingConnection should only have contents when we are transitioning to, out of, or are
// in the CONNECTING state.
private val requestsAwaitingConnection = mutableListOf<RequestAwaitingConnection>()
companion object {
const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT"
const val SERVICE_ENVELOPE_REQUEST_PATH = "/api/v1/message"
@@ -133,11 +143,11 @@ class LibSignalChatConnection(
val stateMonitor = state
.skip(1) // Skip the transition to the initial DISCONNECTED state
.subscribe { nextState ->
if (nextState == WebSocketConnectionState.DISCONNECTED) {
cleanup()
}
CHAT_SERVICE_LOCK.withLock {
if (nextState == WebSocketConnectionState.DISCONNECTED) {
cleanup()
}
stateChangedOrMessageReceivedCondition.signalAll()
}
}
@@ -150,6 +160,17 @@ class LibSignalChatConnection(
// there is no ackSender for a pseudoId gracefully in sendResponse.
ackSenderForInternalPseudoId.clear()
// There's no sense in resetting nextIncomingMessageInternalPseudoId.
// This is a belt-and-suspenders check, because the transition handler leaving the CONNECTING
// state should always cleanup the requestsAwaitingConnection, but in case we miss one, log it
// as an error and clean it up gracefully
if (requestsAwaitingConnection.isNotEmpty()) {
Log.w(TAG, "$name [cleanup] ${requestsAwaitingConnection.size} requestsAwaitingConnection during cleanup! This is probably a bug.")
requestsAwaitingConnection.forEach { pending ->
pending.single.onError(SocketException("Connection terminated unexpectedly"))
}
requestsAwaitingConnection.clear()
}
}
init {
@@ -159,6 +180,42 @@ class LibSignalChatConnection(
}
}
/**
 * Sends [request] over the established chat connection and reports the outcome through [single].
 *
 * Must only be called while the state is CONNECTED; otherwise an IllegalStateException is thrown
 * before anything is sent.
 *
 * @param request the request to transmit
 * @param timeoutSeconds timeout, in seconds, applied to the libsignal request
 * @param single completed with the response once one is received (even if the response carries
 *   an error status), or with an error if no response could be obtained
 */
private fun sendRequestInternal(request: WebSocketRequestMessage, timeoutSeconds: Long, single: SingleSubject<WebsocketResponse>) {
  CHAT_SERVICE_LOCK.withLock {
    check(state.value == WebSocketConnectionState.CONNECTED) {
      "sendRequestInternal requires CONNECTED but state was ${state.value}"
    }

    // Capture the connection while we hold the lock. The completion callbacks below run later,
    // without the lock, and by then the chatConnection field may already have been reset to null;
    // re-reading the field there would misclassify the socket that actually carried the request.
    // The `!!` is backed by the invariant that chatConnection is set whenever state == CONNECTED.
    val connection = chatConnection!!
    val internalRequest = request.toLibSignalRequest(timeout = timeoutSeconds.seconds)

    connection.send(internalRequest)
      .whenComplete(
        onSuccess = { response ->
          Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}")
          when (response.status) {
            in 400..599 -> {
              healthMonitor.onMessageError(
                status = response.status,
                isIdentifiedWebSocket = connection is AuthenticatedChatConnection
              )
            }
          }
          // Here success means "we received the response" even if it is reporting an error.
          // This is consistent with the behavior of the OkHttpWebSocketConnection.
          single.onSuccess(response.toWebsocketResponse(isUnidentified = (connection is UnauthenticatedChatConnection)))
        },
        onFailure = { throwable ->
          Log.w(TAG, "$name [sendRequest] Failure:", throwable)
          val downstreamThrowable = when (throwable) {
            is ConnectionInvalidatedException -> NonSuccessfulResponseCodeException(4401)
            // The clients of WebSocketConnection are often sensitive to the exact type of exception returned.
            // This is the exception that OkHttpWebSocketConnection throws in the closest scenario to this, when
            // the connection fails before the request completes.
            else -> SocketException("Failed to get response for request")
          }
          single.onError(downstreamThrowable)
        }
      )
  }
}
override fun connect(): Observable<WebSocketConnectionState> {
CHAT_SERVICE_LOCK.withLock {
if (!isDead()) {
@@ -175,52 +232,87 @@ class LibSignalChatConnection(
// nullability concern here.
chatConnectionFuture!!.whenComplete(
onSuccess = { connection ->
CHAT_SERVICE_LOCK.withLock {
if (state.value == WebSocketConnectionState.CONNECTING) {
chatConnection = connection
connection?.start()
Log.i(TAG, "$name Connected")
state.onNext(WebSocketConnectionState.CONNECTED)
} else {
Log.i(TAG, "$name Dropped successful connection because we are now ${state.value}")
disconnect()
}
}
handleConnectionSuccess(connection!!)
},
onFailure = { throwable ->
CHAT_SERVICE_LOCK.withLock {
if (throwable is CancellationException) {
// We should have transitioned to DISCONNECTED immediately after we canceled chatConnectionFuture
check(state.value == WebSocketConnectionState.DISCONNECTED)
Log.i(TAG, "$name [connect] cancelled")
return@whenComplete
}
Log.w(TAG, "$name [connect] Failure:", throwable)
chatConnection = null
// Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT
// request returns HTTP 403.
// The chat service currently does not return HTTP 401 on /v1/websocket.
// Thus, this currently matches the implementation in OkHttpWebSocketConnection.
when (throwable) {
is DeviceDeregisteredException -> {
state.onNext(WebSocketConnectionState.AUTHENTICATION_FAILED)
}
is AppExpiredException -> {
state.onNext(WebSocketConnectionState.REMOTE_DEPRECATED)
}
else -> {
Log.w(TAG, "Unknown connection failure reason", throwable)
state.onNext(WebSocketConnectionState.FAILED)
}
}
}
handleConnectionFailure(throwable)
}
)
return state
}
}
/**
 * Completion handler for a connection attempt that produced a usable [connection].
 *
 * If we are still CONNECTING, the connection is adopted and started, the state moves to
 * CONNECTED, and every request queued in requestsAwaitingConnection is sent. If the state has
 * changed in the meantime, the connection is dropped by calling disconnect().
 */
private fun handleConnectionSuccess(connection: ChatConnection) {
  CHAT_SERVICE_LOCK.withLock {
    if (state.value != WebSocketConnectionState.CONNECTING) {
      Log.i(TAG, "$name Dropped successful connection because we are now ${state.value}")
      disconnect()
      return
    }

    chatConnection = connection
    connection.start()
    Log.i(TAG, "$name Connected")
    state.onNext(WebSocketConnectionState.CONNECTED)

    // Flush everything that was queued while connecting. A request that cannot be handed to the
    // connection fails on its own single without affecting the remaining queued requests.
    for (queued in requestsAwaitingConnection) {
      try {
        sendRequestInternal(queued.request, queued.timeoutSeconds, queued.single)
      } catch (e: Throwable) {
        Log.w(TAG, "$name [sendRequest] Failed to send pending request", e)
        queued.single.onError(SocketException("Closed unexpectedly"))
      }
    }
    requestsAwaitingConnection.clear()
  }
}
/**
 * Completion handler for a connection attempt that ended with [throwable].
 *
 * A CancellationException is only logged, since the state is expected to already be
 * DISCONNECTED. Any other failure clears chatConnection, publishes the matching terminal state,
 * and then fails every request queued in requestsAwaitingConnection.
 */
private fun handleConnectionFailure(throwable: Throwable) {
  CHAT_SERVICE_LOCK.withLock {
    if (throwable is CancellationException) {
      // We should have transitioned to DISCONNECTED immediately after we canceled chatConnectionFuture
      check(state.value == WebSocketConnectionState.DISCONNECTED)
      Log.i(TAG, "$name [connect] cancelled")
      return
    }

    Log.w(TAG, "$name [connect] Failure:", throwable)
    chatConnection = null

    // Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT
    // request returns HTTP 403.
    // The chat service currently does not return HTTP 401 on /v1/websocket.
    // Thus, this currently matches the implementation in OkHttpWebSocketConnection.
    val failureState = when (throwable) {
      is DeviceDeregisteredException -> WebSocketConnectionState.AUTHENTICATION_FAILED
      is AppExpiredException -> WebSocketConnectionState.REMOTE_DEPRECATED
      else -> {
        Log.w(TAG, "Unknown connection failure reason", throwable)
        WebSocketConnectionState.FAILED
      }
    }
    state.onNext(failureState)

    // Requests queued while connecting get the same kind of error that OkHttpWebSocketConnection
    // reports when a pending request fails because the underlying transport refused to open.
    val pendingRequestError = if (throwable is DeviceDeregisteredException) {
      NonSuccessfulResponseCodeException(403)
    } else {
      SocketException("Closed unexpectedly")
    }

    for (queued in requestsAwaitingConnection) {
      queued.single.onError(pendingRequestError)
    }
    requestsAwaitingConnection.clear()
  }
}
override fun isDead(): Boolean {
CHAT_SERVICE_LOCK.withLock {
return when (state.value) {
@@ -231,8 +323,7 @@ class LibSignalChatConnection(
WebSocketConnectionState.REMOTE_DEPRECATED -> true
WebSocketConnectionState.CONNECTING,
WebSocketConnectionState.CONNECTED,
WebSocketConnectionState.RECONNECTING -> false
WebSocketConnectionState.CONNECTED -> false
null -> throw IllegalStateException("LibSignalChatConnection.state can never be null")
}
@@ -285,90 +376,20 @@ class LibSignalChatConnection(
val single = SingleSubject.create<WebsocketResponse>()
if (state.value == WebSocketConnectionState.CONNECTING) {
// In OkHttpWebSocketConnection, if a client calls sendRequest while we are still
// connecting to the Chat service, we queue the request to be sent after the
// the connection is established.
// We carry forward that behavior here, except we have to use future chaining
// rather than directly writing to the connection for it to buffer for us,
// because libsignal-net does not expose a connection handle until the connection
// is established.
Log.i(TAG, "[sendRequest] Enqueuing request send for after connection")
// We are in the CONNECTING state, so our invariant says that chatConnectionFuture should
// be set, so we should not have to worry about nullability here.
chatConnectionFuture!!.whenComplete(
onSuccess = {
// We depend on the libsignal's CompletableFuture's synchronization guarantee to
// keep this implementation simple. If another CompletableFuture implementation is
// used, we'll need to add some logic here to be ensure this completion handler
// fires after the one enqueued in connect().
try {
sendRequest(request).subscribe(
{ response ->
single.onSuccess(response)
},
{ error ->
single.onError(error)
}
)
} catch (e: IOException) {
// We failed to send the request because the connection closed between
// when we got the completion callback and when we got scheduled for
// execution. So, we need to propagate that error downstream, but we
// do not need to worry about pendingResponses, because the response
// single was never added to pendingResponses. (It is only added to
// the set after the request is *successfully* sent off.)
// There's also an additional complication that we know from in-the-field
// crash reports that some downstream consumer of the single's error
// call is not resilient to raw IOExceptions, so we need to again mirror
// the OkHttpWebSocketConnection behavior of passing an explicit
// SocketException instead.
single.onError(SocketException("Closed unexpectedly"))
}
},
onFailure = { throwable ->
// This matches the behavior of OkHttpWebSocketConnection when the connection fails
// before the buffered request can be sent.
val downstreamThrowable = when (throwable) {
is DeviceDeregisteredException -> NonSuccessfulResponseCodeException(403)
else -> SocketException("Closed unexpectedly")
}
single.onError(downstreamThrowable)
}
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
}
val internalRequest = request.toLibSignalRequest(timeout = timeoutSeconds.seconds)
chatConnection!!.send(internalRequest)
.whenComplete(
onSuccess = { response ->
Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}")
when (response.status) {
in 400..599 -> {
healthMonitor.onMessageError(
status = response.status,
isIdentifiedWebSocket = chatConnection is AuthenticatedChatConnection
)
}
}
// Here success means "we received the response" even if it is reporting an error.
// This is consistent with the behavior of the OkHttpWebSocketConnection.
single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatConnection is UnauthenticatedChatConnection)))
},
onFailure = { throwable ->
Log.w(TAG, "$name [sendRequest] Failure:", throwable)
val downstreamThrowable = when (throwable) {
is ConnectionInvalidatedException -> NonSuccessfulResponseCodeException(4401)
// The clients of WebSocketConnection are often sensitive to the exact type of exception returned.
// This is the exception that OkHttpWebSocketConnection throws in the closest scenario to this, when
// the connection fails before the request completes.
else -> SocketException("Failed to get response for request")
}
single.onError(downstreamThrowable)
}
)
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
return when (state.value) {
WebSocketConnectionState.CONNECTING -> {
Log.i(TAG, "[sendRequest] Enqueuing request send for after connection")
requestsAwaitingConnection.add(RequestAwaitingConnection(request, timeoutSeconds, single))
single
}
WebSocketConnectionState.CONNECTED -> {
sendRequestInternal(request, timeoutSeconds, single)
single
}
else -> {
throw IllegalStateException("LibSignalChatConnection.state was neither dead, CONNECTING, or CONNECTED.")
}
}.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
}
}

View File

@@ -477,6 +477,103 @@ class LibSignalChatConnectionTest {
sendObserver.assertFailure(IOException().javaClass)
}
@Test
fun regressionTestSendAfterConnectionFutureCompletesButBeforeStateUpdates() {
// We used to have a race condition where if sendRequest was called after
// the chatConnectionFuture completed but before the completion handler
// that updates LibSignalChatConnection's state ran, we would end up with a
// StackOverflowError exception.
// We ended up fixing that bug by refactoring that part of the code completely.
// This tests that scenario to ensure that we don't regress by introducing
// some other kind of bug in that tricky situation.
var connectionFuture: CompletableFuture<UnauthenticatedChatConnection>? = null
// Counted down from inside the blocking completion handler, i.e. once complete() has been called.
val futureCompletedLatch = CountDownLatch(1)
// Counted down at the end of the test to release the blocking completion handler.
val requestCompletedLatch = CountDownLatch(1)
every { network.connectUnauthChat(any()) } answers {
chatListener = firstArg()
connectionFuture = CompletableFuture<UnauthenticatedChatConnection>()
// Add a completion handler that blocks to prevent state transition
connectionFuture!!.whenComplete { _, _ ->
// When we reach this point, we know connectionFuture.complete
// must have been called, and subsequent calls will return false.
futureCompletedLatch.countDown()
// Block to keep state as CONNECTING
requestCompletedLatch.await()
}
connectionFuture!!
}
connection.connect()
executor.submit {
// This will block until all the completion handlers complete, which
// means it will block until requestCompletedLatch is counted down.
connectionFuture!!.complete(chatConnection)
}
assertTrue("connectionFuture was never completed", futureCompletedLatch.await(100, TimeUnit.MILLISECONDS))
// Now calls to connectionFuture.whenComplete will synchronously
// execute the completionHandler given to them, but the state of
// LibSignalChatConnection will still be CONNECTING.
// Previously, this caused a bug where the completion handler would see
// the state was still CONNECTING, and call connectionFuture.whenComplete
// again, thus setting off an infinite recursive loop, ending in a
// StackOverflowError.
connection.sendRequest(WebSocketRequestMessage("GET", "/test"))
// The test passed! Unblock the executor thread.
requestCompletedLatch.countDown()
}
@Test
fun testQueueLargeNumberOfRequestsWhileConnecting() {
  // Queue 100,000 requests while the connection is still CONNECTING, then complete the
  // connection and verify that every queued request is sent and answered successfully.
  val sendRequestCount = 100_000
  val allSentLatch = CountDownLatch(sendRequestCount)
  var connectionCompletionFuture: CompletableFuture<UnauthenticatedChatConnection>? = null

  // Hand out a future that the test completes manually, so the connection stays CONNECTING.
  every { network.connectUnauthChat(any()) } answers {
    chatListener = firstArg()
    CompletableFuture<UnauthenticatedChatConnection>().also { connectionCompletionFuture = it }
  }

  // Each send is answered with a success response and counted on the latch.
  every { chatConnection.send(any()) } answers {
    delay { pendingResponse ->
      pendingResponse.complete(RESPONSE_SUCCESS)
      allSentLatch.countDown()
    }
  }

  connection.connect()

  val sendObservers = (0 until sendRequestCount).map { i ->
    connection.sendRequest(WebSocketRequestMessage("GET", "/test-path-$i")).test()
  }

  // Nothing may complete before the connection itself does.
  for (observer in sendObservers) {
    observer.assertNotComplete()
  }

  connectionCompletionFuture!!.complete(chatConnection)

  assertTrue("All $sendRequestCount were not sent", allSentLatch.await(1, TimeUnit.SECONDS))

  for (observer in sendObservers) {
    observer.awaitDone(100, TimeUnit.MILLISECONDS)
    observer.assertValues(RESPONSE_SUCCESS.toWebsocketResponse(true))
    observer.assertComplete()
  }
}
private fun <T> delay(action: ((CompletableFuture<T>) -> Unit)): CompletableFuture<T> {
val future = CompletableFuture<T>()
executor.submit {