diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index bc0f3b4b44..420419d85a 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -127,15 +127,17 @@ class LibSignalChatConnection( val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) - val stateMonitor = state.subscribe { nextState -> - if (nextState == WebSocketConnectionState.DISCONNECTED) { - cleanup() - } + val stateMonitor = state + .skip(1) // Skip the transition to the initial DISCONNECTED state + .subscribe { nextState -> + if (nextState == WebSocketConnectionState.DISCONNECTED) { + cleanup() + } - CHAT_SERVICE_LOCK.withLock { - stateChangedOrMessageReceivedCondition.signalAll() + CHAT_SERVICE_LOCK.withLock { + stateChangedOrMessageReceivedCondition.signalAll() + } } - } private fun cleanup() { Log.i(TAG, "$name [cleanup]") @@ -243,10 +245,14 @@ class LibSignalChatConnection( chatConnection!!.disconnect() .whenComplete( onSuccess = { - Log.i(TAG, "$name Disconnected") - state.onNext(WebSocketConnectionState.DISCONNECTED) + // This future completion means the WebSocket close frame has been sent off, but we + // have not yet received a close frame back from the server. + // To match the behavior of OkHttpWebSocketConnection, we should transition to DISCONNECTED + // only when we get the close frame back from the server, which happens when + // onConnectionInterrupted is called. }, onFailure = { throwable -> + // We failed to write the close frame to the server? Something is very wrong, give up and tear down. Log.w(TAG, "$name Disconnect failed", throwable) state.onNext(WebSocketConnectionState.DISCONNECTED) } @@ -510,7 +516,13 @@ class LibSignalChatConnection( override fun onConnectionInterrupted(chat: ChatConnection, disconnectReason: ChatServiceException?) { CHAT_SERVICE_LOCK.withLock { - Log.i(TAG, "$name connection interrupted", disconnectReason) + if (disconnectReason == null) { + // disconnectReason = null means we requested this disconnect earlier, and this is confirmation + // that disconnection is complete. + Log.i(TAG, "$name disconnected") + } else { + Log.i(TAG, "$name connection unexpectedly closed", disconnectReason) + } chatConnection = null state.onNext(WebSocketConnectionState.DISCONNECTED) } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index 1221a72f27..4330c73eba 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -8,6 +8,7 @@ import io.mockk.verify import io.reactivex.rxjava3.observers.TestObserver import okio.ByteString.Companion.toByteString import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotEquals import org.junit.Assert.assertThrows import org.junit.Assert.assertTrue import org.junit.Before @@ -109,6 +110,7 @@ class LibSignalChatConnectionTest { WebSocketConnectionState.CONNECTING, WebSocketConnectionState.CONNECTED ) + observer.assertNoConsecutiveDuplicates() } // Test that the LibSignalChatConnection transitions to FAILED if the @@ -139,6 +141,7 @@ class LibSignalChatConnectionTest { WebSocketConnectionState.CONNECTING, WebSocketConnectionState.FAILED ) + observer.assertNoConsecutiveDuplicates() } // Test connect followed by disconnect, checking the state transitions. @@ -157,6 +160,10 @@ class LibSignalChatConnectionTest { connection.disconnect() disconnectLatch!!.await(100, TimeUnit.MILLISECONDS) + // onConnectionInterrupted acts like the onClosed callback for the connection here, driving the + // transition from DISCONNECTING -> DISCONNECTED. + chatListener!!.onConnectionInterrupted(chatConnection, null) + observer.assertNotComplete() observer.assertValues( WebSocketConnectionState.DISCONNECTED, @@ -165,6 +172,7 @@ class LibSignalChatConnectionTest { WebSocketConnectionState.DISCONNECTING, WebSocketConnectionState.DISCONNECTED ) + observer.assertNoConsecutiveDuplicates() } // Test that a disconnect failure transitions from CONNECTED -> DISCONNECTING -> DISCONNECTED anyway, @@ -197,6 +205,7 @@ class LibSignalChatConnectionTest { WebSocketConnectionState.DISCONNECTING, WebSocketConnectionState.DISCONNECTED ) + observer.assertNoConsecutiveDuplicates() } // Test a successful keepAlive, i.e. we get a 200 OK in response to the keepAlive request, @@ -278,6 +287,7 @@ class LibSignalChatConnectionTest { // Disconnects as a result of keep-alive failure WebSocketConnectionState.DISCONNECTED ) + observer.assertNoConsecutiveDuplicates() verify(exactly = 0) { healthMonitor.onKeepAliveResponse(any(), any()) healthMonitor.onMessageError(any(), any()) @@ -303,6 +313,7 @@ class LibSignalChatConnectionTest { // Disconnects as a result of the connection interrupted event WebSocketConnectionState.DISCONNECTED ) + observer.assertNoConsecutiveDuplicates() verify(exactly = 0) { healthMonitor.onKeepAliveResponse(any(), any()) healthMonitor.onMessageError(any(), any()) @@ -472,6 +483,17 @@ class LibSignalChatConnectionTest { return future } + private fun TestObserver.assertNoConsecutiveDuplicates() { + val states = this.values() + for (i in 1 until states.size) { + assertNotEquals( + "Found duplicate consecutive states states[${i - 1}] = states[$i] = ${states[i]}", + states[i - 1], + states[i] + ) + } + } + companion object { // For verifying success / error scenarios in keepAlive tests, etc. private val RESPONSE_SUCCESS = ChatConnection.Response(200, "", emptyMap(), byteArrayOf())