diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt b/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt index dddb22c814..7c1a501aef 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/util/RemoteConfig.kt @@ -1005,7 +1005,7 @@ object RemoteConfig { @JvmStatic @get:JvmName("libSignalWebSocketEnabled") val libSignalWebSocketEnabled: Boolean by remoteValue( - key = "android.libsignalWebSocketEnabled.4", + key = "android.libsignalWebSocketEnabled.5", hotSwappable = false ) { value -> value.asBoolean(false) || Environment.IS_NIGHTLY diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/WebSocketConnectionState.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/WebSocketConnectionState.java index ce09f0f784..49b18fd220 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/WebSocketConnectionState.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/WebSocketConnectionState.java @@ -7,7 +7,6 @@ public enum WebSocketConnectionState { DISCONNECTED, CONNECTING, CONNECTED, - RECONNECTING, DISCONNECTING, AUTHENTICATION_FAILED, REMOTE_DEPRECATED, diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index c1f7876163..e58c9c29b4 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -81,8 +81,14 @@ class LibSignalChatConnection( private val nextIncomingMessageInternalPseudoId = AtomicLong(1) val ackSenderForInternalPseudoId = ConcurrentHashMap() - // CHAT_SERVICE_LOCK: Protects state, stateChangedOrMessageReceivedCondition, chatConnection, and - // chatConnectionFuture + private data class RequestAwaitingConnection( + val request: WebSocketRequestMessage, + val timeoutSeconds: Long, + val single: SingleSubject + ) + + // CHAT_SERVICE_LOCK: Protects state, stateChangedOrMessageReceivedCondition, chatConnection, + // chatConnectionFuture, and requestsAwaitingConnection. // stateChangedOrMessageReceivedCondition: derived from CHAT_SERVICE_LOCK, used by readRequest(), // exists to emulate idiosyncratic behavior of OkHttpWebSocketConnection for readRequest() // chatConnection: Set only when state == CONNECTED @@ -92,6 +98,10 @@ class LibSignalChatConnection( private var chatConnection: ChatConnection? = null private var chatConnectionFuture: CompletableFuture? = null + // requestsAwaitingConnection should only have contents when we are transitioning to, out of, or are + // in the CONNECTING state. + private val requestsAwaitingConnection = mutableListOf() + companion object { const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT" const val SERVICE_ENVELOPE_REQUEST_PATH = "/api/v1/message" @@ -133,11 +143,11 @@ class LibSignalChatConnection( val stateMonitor = state .skip(1) // Skip the transition to the initial DISCONNECTED state .subscribe { nextState -> - if (nextState == WebSocketConnectionState.DISCONNECTED) { - cleanup() - } - CHAT_SERVICE_LOCK.withLock { + if (nextState == WebSocketConnectionState.DISCONNECTED) { + cleanup() + } + stateChangedOrMessageReceivedCondition.signalAll() } } @@ -150,6 +160,17 @@ class LibSignalChatConnection( // there is no ackSender for a pseudoId gracefully in sendResponse. ackSenderForInternalPseudoId.clear() // There's no sense in resetting nextIncomingMessageInternalPseudoId. + + // This is a belt-and-suspenders check, because the transition handler leaving the CONNECTING + // state should always cleanup the requestsAwaitingConnection, but in case we miss one, log it + // as an error and clean it up gracefully + if (requestsAwaitingConnection.isNotEmpty()) { + Log.w(TAG, "$name [cleanup] ${requestsAwaitingConnection.size} requestsAwaitingConnection during cleanup! This is probably a bug.") + requestsAwaitingConnection.forEach { pending -> + pending.single.onError(SocketException("Connection terminated unexpectedly")) + } + requestsAwaitingConnection.clear() + } } init { @@ -159,6 +180,42 @@ class LibSignalChatConnection( } } + private fun sendRequestInternal(request: WebSocketRequestMessage, timeoutSeconds: Long, single: SingleSubject) { + CHAT_SERVICE_LOCK.withLock { + check(state.value == WebSocketConnectionState.CONNECTED) + + val internalRequest = request.toLibSignalRequest(timeout = timeoutSeconds.seconds) + chatConnection!!.send(internalRequest) + .whenComplete( + onSuccess = { response -> + Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}") + when (response.status) { + in 400..599 -> { + healthMonitor.onMessageError( + status = response.status, + isIdentifiedWebSocket = chatConnection is AuthenticatedChatConnection + ) + } + } + // Here success means "we received the response" even if it is reporting an error. + // This is consistent with the behavior of the OkHttpWebSocketConnection. + single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatConnection is UnauthenticatedChatConnection))) + }, + onFailure = { throwable -> + Log.w(TAG, "$name [sendRequest] Failure:", throwable) + val downstreamThrowable = when (throwable) { + is ConnectionInvalidatedException -> NonSuccessfulResponseCodeException(4401) + // The clients of WebSocketConnection are often sensitive to the exact type of exception returned. + // This is the exception that OkHttpWebSocketConnection throws in the closest scenario to this, when + // the connection fails before the request completes. + else -> SocketException("Failed to get response for request") + } + single.onError(downstreamThrowable) + } + ) + } + } + override fun connect(): Observable { CHAT_SERVICE_LOCK.withLock { if (!isDead()) { @@ -175,52 +232,87 @@ class LibSignalChatConnection( // nullability concern here. chatConnectionFuture!!.whenComplete( onSuccess = { connection -> - CHAT_SERVICE_LOCK.withLock { - if (state.value == WebSocketConnectionState.CONNECTING) { - chatConnection = connection - connection?.start() - Log.i(TAG, "$name Connected") - state.onNext(WebSocketConnectionState.CONNECTED) - } else { - Log.i(TAG, "$name Dropped successful connection because we are now ${state.value}") - disconnect() - } - } + handleConnectionSuccess(connection!!) }, onFailure = { throwable -> - CHAT_SERVICE_LOCK.withLock { - if (throwable is CancellationException) { - // We should have transitioned to DISCONNECTED immediately after we canceled chatConnectionFuture - check(state.value == WebSocketConnectionState.DISCONNECTED) - Log.i(TAG, "$name [connect] cancelled") - return@whenComplete - } - - Log.w(TAG, "$name [connect] Failure:", throwable) - chatConnection = null - // Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT - // request returns HTTP 403. - // The chat service currently does not return HTTP 401 on /v1/websocket. - // Thus, this currently matches the implementation in OkHttpWebSocketConnection. - when (throwable) { - is DeviceDeregisteredException -> { - state.onNext(WebSocketConnectionState.AUTHENTICATION_FAILED) - } - is AppExpiredException -> { - state.onNext(WebSocketConnectionState.REMOTE_DEPRECATED) - } - else -> { - Log.w(TAG, "Unknown connection failure reason", throwable) - state.onNext(WebSocketConnectionState.FAILED) - } - } - } + handleConnectionFailure(throwable) } ) return state } } + private fun handleConnectionSuccess(connection: ChatConnection) { + CHAT_SERVICE_LOCK.withLock { + when (state.value) { + WebSocketConnectionState.CONNECTING -> { + chatConnection = connection + chatConnection?.start() + Log.i(TAG, "$name Connected") + state.onNext(WebSocketConnectionState.CONNECTED) + + requestsAwaitingConnection.forEach { pending -> + runCatching { + sendRequestInternal(pending.request, pending.timeoutSeconds, pending.single) + }.onFailure { e -> + Log.w(TAG, "$name [sendRequest] Failed to send pending request", e) + pending.single.onError(SocketException("Closed unexpectedly")) + } + } + + requestsAwaitingConnection.clear() + } + else -> { + Log.i(TAG, "$name Dropped successful connection because we are now ${state.value}") + disconnect() + } + } + } + } + + private fun handleConnectionFailure(throwable: Throwable) { + CHAT_SERVICE_LOCK.withLock { + if (throwable is CancellationException) { + // We should have transitioned to DISCONNECTED immediately after we canceled chatConnectionFuture + check(state.value == WebSocketConnectionState.DISCONNECTED) + Log.i(TAG, "$name [connect] cancelled") + return + } + + Log.w(TAG, "$name [connect] Failure:", throwable) + chatConnection = null + + // Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT + // request returns HTTP 403. + // The chat service currently does not return HTTP 401 on /v1/websocket. + // Thus, this currently matches the implementation in OkHttpWebSocketConnection. + when (throwable) { + is DeviceDeregisteredException -> { + state.onNext(WebSocketConnectionState.AUTHENTICATION_FAILED) + } + is AppExpiredException -> { + state.onNext(WebSocketConnectionState.REMOTE_DEPRECATED) + } + else -> { + Log.w(TAG, "Unknown connection failure reason", throwable) + state.onNext(WebSocketConnectionState.FAILED) + } + } + + val downstreamThrowable = when (throwable) { + is DeviceDeregisteredException -> NonSuccessfulResponseCodeException(403) + // This is just to match what OkHttpWebSocketConnection does in the case a pending request fails + // due to the underlying transport refusing to open. + else -> SocketException("Closed unexpectedly") + } + + requestsAwaitingConnection.forEach { pending -> + pending.single.onError(downstreamThrowable) + } + requestsAwaitingConnection.clear() + } + } + override fun isDead(): Boolean { CHAT_SERVICE_LOCK.withLock { return when (state.value) { @@ -231,8 +323,7 @@ class LibSignalChatConnection( WebSocketConnectionState.REMOTE_DEPRECATED -> true WebSocketConnectionState.CONNECTING, - WebSocketConnectionState.CONNECTED, - WebSocketConnectionState.RECONNECTING -> false + WebSocketConnectionState.CONNECTED -> false null -> throw IllegalStateException("LibSignalChatConnection.state can never be null") } @@ -285,90 +376,20 @@ class LibSignalChatConnection( val single = SingleSubject.create() - if (state.value == WebSocketConnectionState.CONNECTING) { - // In OkHttpWebSocketConnection, if a client calls sendRequest while we are still - // connecting to the Chat service, we queue the request to be sent after the - // the connection is established. - // We carry forward that behavior here, except we have to use future chaining - // rather than directly writing to the connection for it to buffer for us, - // because libsignal-net does not expose a connection handle until the connection - // is established. - Log.i(TAG, "[sendRequest] Enqueuing request send for after connection") - // We are in the CONNECTING state, so our invariant says that chatConnectionFuture should - // be set, so we should not have to worry about nullability here. - chatConnectionFuture!!.whenComplete( - onSuccess = { - // We depend on the libsignal's CompletableFuture's synchronization guarantee to - // keep this implementation simple. If another CompletableFuture implementation is - // used, we'll need to add some logic here to be ensure this completion handler - // fires after the one enqueued in connect(). - try { - sendRequest(request).subscribe( - { response -> - single.onSuccess(response) - }, - { error -> - single.onError(error) - } - ) - } catch (e: IOException) { - // We failed to send the request because the connection closed between - // when we got the completion callback and when we got scheduled for - // execution. So, we need to propagate that error downstream, but we - // do not need to worry about pendingResponses, because the response - // single was never added to pendingResponses. (It is only added to - // the set after the request is *successfully* sent off.) - // There's also an additional complication that we know from in-the-field - // crash reports that some downstream consumer of the single's error - // call is not resilient to raw IOExceptions, so we need to again mirror - // the OkHttpWebSocketConnection behavior of passing an explicit - // SocketException instead. - single.onError(SocketException("Closed unexpectedly")) - } - }, - onFailure = { throwable -> - // This matches the behavior of OkHttpWebSocketConnection when the connection fails - // before the buffered request can be sent. - val downstreamThrowable = when (throwable) { - is DeviceDeregisteredException -> NonSuccessfulResponseCodeException(403) - else -> SocketException("Closed unexpectedly") - } - single.onError(downstreamThrowable) - } - ) - return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) - } - - val internalRequest = request.toLibSignalRequest(timeout = timeoutSeconds.seconds) - chatConnection!!.send(internalRequest) - .whenComplete( - onSuccess = { response -> - Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}") - when (response.status) { - in 400..599 -> { - healthMonitor.onMessageError( - status = response.status, - isIdentifiedWebSocket = chatConnection is AuthenticatedChatConnection - ) - } - } - // Here success means "we received the response" even if it is reporting an error. - // This is consistent with the behavior of the OkHttpWebSocketConnection. - single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatConnection is UnauthenticatedChatConnection))) - }, - onFailure = { throwable -> - Log.w(TAG, "$name [sendRequest] Failure:", throwable) - val downstreamThrowable = when (throwable) { - is ConnectionInvalidatedException -> NonSuccessfulResponseCodeException(4401) - // The clients of WebSocketConnection are often sensitive to the exact type of exception returned. - // This is the exception that OkHttpWebSocketConnection throws in the closest scenario to this, when - // the connection fails before the request completes. - else -> SocketException("Failed to get response for request") - } - single.onError(downstreamThrowable) - } - ) - return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) + return when (state.value) { + WebSocketConnectionState.CONNECTING -> { + Log.i(TAG, "[sendRequest] Enqueuing request send for after connection") + requestsAwaitingConnection.add(RequestAwaitingConnection(request, timeoutSeconds, single)) + single + } + WebSocketConnectionState.CONNECTED -> { + sendRequestInternal(request, timeoutSeconds, single) + single + } + else -> { + throw IllegalStateException("LibSignalChatConnection.state was neither dead, CONNECTING, or CONNECTED.") + } + }.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) } } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index 19d1780671..d89ce140e1 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -477,6 +477,103 @@ class LibSignalChatConnectionTest { sendObserver.assertFailure(IOException().javaClass) } + @Test + fun regressionTestSendAfterConnectionFutureCompletesButBeforeStateUpdates() { + // We used to have a race condition where if sendRequest was called after + // the chatConnectionFuture completed but before the completion handler that + // that updates LibSignalChatConnection's state ran, we would end up with a + // StackOverflowError exception. + // We ended up fixing that bug by refactoring that part of the code completely. + // This tests that scenario to ensure that we don't regress by introducing + // some other kind of bug in that tricky situation. + var connectionFuture: CompletableFuture? = null + val futureCompletedLatch = CountDownLatch(1) + val requestCompletedLatch = CountDownLatch(1) + + every { network.connectUnauthChat(any()) } answers { + chatListener = firstArg() + connectionFuture = CompletableFuture() + + // Add a completion handler that blocks to prevent state transition + connectionFuture!!.whenComplete { _, _ -> + // When we reach this point, we know connectionFuture.complete + // must have been called, and subsequent calls will return false. + futureCompletedLatch.countDown() + // Block to keep state as CONNECTING + requestCompletedLatch.await() + } + + connectionFuture!! + } + + connection.connect() + + executor.submit { + // This will block until all the completion handlers complete, which + // means it will block until requestCompletedLatch is counted down. + connectionFuture!!.complete(chatConnection) + } + + assertTrue("connectionFuture was never completed", futureCompletedLatch.await(100, TimeUnit.MILLISECONDS)) + + // Now calls to connectionFuture.whenComplete will synchronously + // execute the completionHandler given to them, but the state of + // LibSignalChatConnection will still be CONNECTING. + // Previously, this caused a bug where the completion handler would see + // the state was still CONNECTING, and call connectionFuture.whenComplete + // again, thus setting off an infinite recursive loop, ending in a + // StackOverflowError. + connection.sendRequest(WebSocketRequestMessage("GET", "/test")) + + // The test passed! Unblock the executor thread. + requestCompletedLatch.countDown() + } + + @Test + fun testQueueLargeNumberOfRequestsWhileConnecting() { + // Test queuing up 100,000 requests while the connection is still CONNECTING, + // then complete the connection to make sure they all send successfully. + var connectionCompletionFuture: CompletableFuture? = null + val sendRequestCount = 100_000 + val allSentLatch = CountDownLatch(sendRequestCount) + + every { network.connectUnauthChat(any()) } answers { + chatListener = firstArg() + connectionCompletionFuture = CompletableFuture() + connectionCompletionFuture!! + } + + every { chatConnection.send(any()) } answers { + delay { + it.complete(RESPONSE_SUCCESS) + allSentLatch.countDown() + } + } + + connection.connect() + + val sendObservers = mutableListOf>() + for (i in 0 until sendRequestCount) { + val sendSingle = connection.sendRequest(WebSocketRequestMessage("GET", "/test-path-$i")) + val observer = sendSingle.test() + sendObservers.add(observer) + } + + sendObservers.forEach { observer -> + observer.assertNotComplete() + } + + connectionCompletionFuture!!.complete(chatConnection) + + assertTrue("All $sendRequestCount were not sent", allSentLatch.await(1, TimeUnit.SECONDS)) + + sendObservers.forEach { observer -> + observer.awaitDone(100, TimeUnit.MILLISECONDS) + observer.assertValues(RESPONSE_SUCCESS.toWebsocketResponse(true)) + observer.assertComplete() + } + } + private fun delay(action: ((CompletableFuture) -> Unit)): CompletableFuture { val future = CompletableFuture() executor.submit {