From 2186e2bf921b298ec180db3a76db6bcd9369b8f3 Mon Sep 17 00:00:00 2001 From: andrew-signal Date: Tue, 4 Feb 2025 14:34:07 -0500 Subject: [PATCH] Update LibSignalChatConnection to use new ChatConnection API rather than ChatService --- .../ApplicationDependencyProvider.java | 2 +- .../websocket/LibSignalChatConnection.kt | 115 +++++--- .../websocket/LibSignalNetworkExtensions.kt | 20 -- .../websocket/ShadowingWebSocketConnection.kt | 57 ++-- .../websocket/LibSignalChatConnectionTest.kt | 277 +++++++----------- 5 files changed, 218 insertions(+), 253 deletions(-) diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index 8e46c49856..37b75cc342 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -438,7 +438,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { BuildConfig.SIGNAL_AGENT, healthMonitor, Stories.isFeatureEnabled(), - LibSignalNetworkExtensions.createChatService(libSignalNetworkSupplier.get(), null, Stories.isFeatureEnabled(), null), + libSignalNetworkSupplier.get(), shadowPercentage, bridge ); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index 588f4fe9fd..671bd7787e 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -12,14 +12,16 @@ import io.reactivex.rxjava3.subjects.BehaviorSubject import io.reactivex.rxjava3.subjects.SingleSubject import okio.ByteString import okio.ByteString.Companion.toByteString +import okio.withLock import org.signal.core.util.logging.Log -import org.signal.libsignal.net.AuthenticatedChatService -import org.signal.libsignal.net.ChatListener -import org.signal.libsignal.net.ChatService +import org.signal.libsignal.internal.CompletableFuture +import org.signal.libsignal.net.AuthenticatedChatConnection +import org.signal.libsignal.net.ChatConnection +import org.signal.libsignal.net.ChatConnectionListener import org.signal.libsignal.net.ChatServiceException import org.signal.libsignal.net.DeviceDeregisteredException import org.signal.libsignal.net.Network -import org.signal.libsignal.net.UnauthenticatedChatService +import org.signal.libsignal.net.UnauthenticatedChatConnection import org.whispersystems.signalservice.api.util.CredentialsProvider import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState @@ -35,21 +37,20 @@ import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.withLock import kotlin.time.Duration.Companion.seconds -import org.signal.libsignal.net.ChatService.Request as LibSignalRequest -import org.signal.libsignal.net.ChatService.Response as LibSignalResponse +import org.signal.libsignal.net.ChatConnection.Request as LibSignalRequest +import org.signal.libsignal.net.ChatConnection.Response as LibSignalResponse /** * Implements the WebSocketConnection interface via libsignal-net * * Notable implementation choices: - * - [chatService] contains both the authenticated and unauthenticated connections, - * which one to use for [sendRequest]/[sendResponse] is based on [isAuthenticated]. - * - keep-alive requests always use the [org.signal.libsignal.net.ChatService.unauthenticatedSendAndDebug] - * API, and log the debug info on success. - * - regular sends use [org.signal.libsignal.net.ChatService.unauthenticatedSend] and don't create any overhead. + * - [chatConnection] contains either an authenticated or an unauthenticated connections + * - keep-alive requests are sent on both authenticated and unauthenticated connections, mirroring the existing OkHttp behavior * - [org.whispersystems.signalservice.api.websocket.WebSocketConnectionState] reporting is implemented * as close as possible to the original implementation in * [org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection]. + * - we expose fake "psuedo IDs" for incoming requests so the layer on top of ours can work with IDs, just + * like with the old OkHttp implementation, and internally we map these IDs to AckSenders */ class LibSignalChatConnection( name: String, @@ -73,10 +74,10 @@ class LibSignalChatConnection( // tell us to send a response for that ID, and then we use the pseudo ID as a handle to find // the callback given to us earlier by libsignal-net, and we call that callback. private val nextIncomingMessageInternalPseudoId = AtomicLong(1) - val ackSenderForInternalPseudoId = ConcurrentHashMap() + val ackSenderForInternalPseudoId = ConcurrentHashMap() private val CHAT_SERVICE_LOCK = ReentrantLock() - private var chatService: ChatService? = null + private var chatConnection: ChatConnection? = null companion object { const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT" @@ -142,23 +143,38 @@ class LibSignalChatConnection( // There's no sense in resetting nextIncomingMessageInternalPseudoId. } + init { + if (credentialsProvider != null) { + check(!credentialsProvider.username.isNullOrEmpty()) + check(!credentialsProvider.password.isNullOrEmpty()) + } + } + override fun connect(): Observable { CHAT_SERVICE_LOCK.withLock { - if (chatService != null) { + if (!isDead()) { return state } - Log.i(TAG, "$name Connecting...") - chatService = network.createChatService(credentialsProvider, receiveStories, listener).apply { - state.onNext(WebSocketConnectionState.CONNECTING) - connect().whenComplete( - onSuccess = { debugInfo -> + val chatConnectionFuture: CompletableFuture = if (credentialsProvider == null) { + network.connectUnauthChat(listener) + } else { + network.connectAuthChat(credentialsProvider.username, credentialsProvider.password, receiveStories, listener) + } + state.onNext(WebSocketConnectionState.CONNECTING) + chatConnectionFuture.whenComplete( + onSuccess = { connection -> + CHAT_SERVICE_LOCK.withLock { + chatConnection = connection + connection?.start() Log.i(TAG, "$name Connected") - Log.d(TAG, "$name $debugInfo") state.onNext(WebSocketConnectionState.CONNECTED) - }, - onFailure = { throwable -> + } + }, + onFailure = { throwable -> + CHAT_SERVICE_LOCK.withLock { Log.w(TAG, "$name [connect] Failure:", throwable) + chatConnection = null // Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT // request returns HTTP 403. // The chat service currently does not return HTTP 401 on /v1/websocket. @@ -169,27 +185,38 @@ class LibSignalChatConnection( state.onNext(WebSocketConnectionState.FAILED) } } - ) - } + } + ) return state } } override fun isDead(): Boolean { CHAT_SERVICE_LOCK.withLock { - return chatService == null + return when (state.value) { + WebSocketConnectionState.DISCONNECTED, + WebSocketConnectionState.DISCONNECTING, + WebSocketConnectionState.FAILED, + WebSocketConnectionState.AUTHENTICATION_FAILED -> true + + WebSocketConnectionState.CONNECTING, + WebSocketConnectionState.CONNECTED, + WebSocketConnectionState.RECONNECTING -> false + + null -> throw IllegalStateException("LibSignalChatConnection.state can never be null") + } } } override fun disconnect() { CHAT_SERVICE_LOCK.withLock { - if (chatService == null) { + if (isDead()) { return } Log.i(TAG, "$name Disconnecting...") state.onNext(WebSocketConnectionState.DISCONNECTING) - chatService!!.disconnect() + chatConnection!!.disconnect() .whenComplete( onSuccess = { Log.i(TAG, "$name Disconnected") @@ -200,18 +227,18 @@ class LibSignalChatConnection( state.onNext(WebSocketConnectionState.DISCONNECTED) } ) - chatService = null + chatConnection = null } } override fun sendRequest(request: WebSocketRequestMessage): Single { CHAT_SERVICE_LOCK.withLock { - if (chatService == null) { + if (isDead()) { return Single.error(IOException("$name is closed!")) } val single = SingleSubject.create() val internalRequest = request.toLibSignalRequest() - chatService!!.send(internalRequest) + chatConnection!!.send(internalRequest) .whenComplete( onSuccess = { response -> Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}") @@ -219,13 +246,13 @@ class LibSignalChatConnection( in 400..599 -> { healthMonitor.onMessageError( status = response.status, - isIdentifiedWebSocket = chatService is AuthenticatedChatService + isIdentifiedWebSocket = chatConnection is AuthenticatedChatConnection ) } } // Here success means "we received the response" even if it is reporting an error. // This is consistent with the behavior of the OkHttpWebSocketConnection. - single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService))) + single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatConnection is UnauthenticatedChatConnection))) }, onFailure = { throwable -> Log.w(TAG, "$name [sendRequest] Failure:", throwable) @@ -238,29 +265,29 @@ class LibSignalChatConnection( override fun sendKeepAlive() { CHAT_SERVICE_LOCK.withLock { - if (chatService == null) { + if (isDead()) { return } Log.i(TAG, "$name Sending keep alive...") - chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST) + chatConnection!!.send(KEEP_ALIVE_REQUEST) .whenComplete( - onSuccess = { debugResponse -> + onSuccess = { response -> Log.d(TAG, "$name [sendKeepAlive] Success") - when (debugResponse!!.response.status) { + when (response!!.status) { in 200..299 -> { healthMonitor.onKeepAliveResponse( sentTimestamp = Instant.now().toEpochMilli(), // ignored. can be any value - isIdentifiedWebSocket = chatService is AuthenticatedChatService + isIdentifiedWebSocket = chatConnection is AuthenticatedChatConnection ) } in 400..599 -> { - healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService)) + healthMonitor.onMessageError(response.status, (chatConnection is AuthenticatedChatConnection)) } else -> { - Log.w(TAG, "$name [sendKeepAlive] Unsupported keep alive response status: ${debugResponse.response.status}") + Log.w(TAG, "$name [sendKeepAlive] Unsupported keep alive response status: ${response.status}") } } }, @@ -310,8 +337,8 @@ class LibSignalChatConnection( private val listener = LibSignalChatListener() - private inner class LibSignalChatListener : ChatListener { - override fun onIncomingMessage(chat: ChatService, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatListener.ServerMessageAck?) { + private inner class LibSignalChatListener : ChatConnectionListener { + override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) { // NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is // already in the ackSender map, if it exists. val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement() @@ -328,15 +355,15 @@ class LibSignalChatConnection( incomingRequestQueue.put(incomingWebSocketRequest) } - override fun onConnectionInterrupted(chat: ChatService, disconnectReason: ChatServiceException) { + override fun onConnectionInterrupted(chat: ChatConnection, disconnectReason: ChatServiceException) { CHAT_SERVICE_LOCK.withLock { Log.i(TAG, "$name connection interrupted", disconnectReason) - chatService = null + chatConnection = null state.onNext(WebSocketConnectionState.DISCONNECTED) } } - override fun onQueueEmpty(chat: ChatService) { + override fun onQueueEmpty(chat: ChatConnection) { val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement() val queueEmptyRequest = WebSocketRequestMessage( verb = SOCKET_EMPTY_REQUEST_VERB, diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalNetworkExtensions.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalNetworkExtensions.kt index e3b8d80fd1..26210afd73 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalNetworkExtensions.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalNetworkExtensions.kt @@ -7,29 +7,9 @@ package org.whispersystems.signalservice.internal.websocket import org.signal.core.util.orNull -import org.signal.libsignal.net.ChatListener -import org.signal.libsignal.net.ChatService import org.signal.libsignal.net.Network -import org.whispersystems.signalservice.api.util.CredentialsProvider import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration -/** - * Helper method to create a ChatService with optional credentials. - */ -fun Network.createChatService( - credentialsProvider: CredentialsProvider? = null, - receiveStories: Boolean, - listener: ChatListener? = null -): ChatService { - val username = credentialsProvider?.username ?: "" - val password = credentialsProvider?.password ?: "" - return if (username.isEmpty() && password.isEmpty()) { - this.createUnauthChatService(listener) - } else { - this.createAuthChatService(username, password, receiveStories, listener) - } -} - /** * Helper method to apply settings from the SignalServiceConfiguration. */ diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/ShadowingWebSocketConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/ShadowingWebSocketConnection.kt index 999cfee311..58d43fe438 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/ShadowingWebSocketConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/ShadowingWebSocketConnection.kt @@ -10,7 +10,9 @@ import io.reactivex.rxjava3.core.Single import okhttp3.Response import okhttp3.WebSocket import org.signal.core.util.logging.Log -import org.signal.libsignal.net.ChatService +import org.signal.libsignal.net.ChatConnection +import org.signal.libsignal.net.Network +import org.signal.libsignal.net.UnauthenticatedChatConnection import org.whispersystems.signalservice.api.util.CredentialsProvider import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState @@ -46,7 +48,7 @@ class ShadowingWebSocketConnection( signalAgent: String, healthMonitor: HealthMonitor, allowStories: Boolean, - private val chatService: ChatService, + private val network: Network, private val shadowPercentage: Int, private val bridge: WebSocketShadowingBridge ) : OkHttpWebSocketConnection( @@ -67,19 +69,33 @@ class ShadowingWebSocketConnection( } private val canShadow: AtomicBoolean = AtomicBoolean(false) private val executor: ExecutorService = Executors.newSingleThreadExecutor() + private var chatConnection: UnauthenticatedChatConnection? = null + private var shadowingConnectPending = false override fun connect(): Observable { - executor.submit { - chatService.connect().whenComplete( - onSuccess = { - canShadow.set(true) - Log.i(TAG, "Shadow socket connected.") - }, - onFailure = { - canShadow.set(false) - Log.i(TAG, "Shadow socket failed to connect.") - } - ) + // NB: The potential for race conditions here was introduced when we switched from ChatService's + // long lived connection model to the single-use ChatConnection model. + // At this time, we do not intend to ever use this code in production again, so I'm deferring properly + // fixing it with a refactor, and instead just doing the bare minimum to avoid an obvious race. + // If we do want to use this again in production, we should probably refactor to depend on the higher level + // LibSignalChatConnection, rather than the lower level ChatConnection API. + if (chatConnection == null && !shadowingConnectPending) { + shadowingConnectPending = true + executor.submit { + network.connectUnauthChat(null).whenComplete( + onSuccess = { connection -> + shadowingConnectPending = false + chatConnection = connection + canShadow.set(true) + Log.i(TAG, "Shadow socket connected.") + }, + onFailure = { + shadowingConnectPending = false + canShadow.set(false) + Log.i(TAG, "Shadow socket failed to connect.") + } + ) + } } return super.connect() } @@ -96,7 +112,7 @@ class ShadowingWebSocketConnection( override fun disconnect() { executor.submit { - chatService.disconnect().thenApply { + chatConnection?.disconnect()?.thenApply { canShadow.set(false) Log.i(TAG, "Shadow socket disconnected.") } @@ -133,22 +149,23 @@ class ShadowingWebSocketConnection( } private fun libsignalKeepAlive(actualResponse: WebsocketResponse) { - val request = ChatService.Request( + val connection = chatConnection ?: return + val request = ChatConnection.Request( "GET", "/v1/keepalive", emptyMap(), ByteArray(0), KEEP_ALIVE_TIMEOUT.inWholeMilliseconds.toInt() ) - chatService.sendAndDebug(request) - .whenComplete( - onSuccess = { + connection.send(request) + ?.whenComplete( + onSuccess = { response -> stats.requestsCompared.incrementAndGet() - val goodStatus = (it?.response?.status ?: -1) in 200..299 + val goodStatus = (response?.status ?: -1) in 200..299 if (!goodStatus) { stats.badStatuses.incrementAndGet() } - Log.i(TAG, "$it") + Log.i(TAG, response?.message) }, onFailure = { stats.requestsCompared.incrementAndGet() diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index db67a8d0ab..bbb0303c25 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -1,9 +1,9 @@ package org.whispersystems.signalservice.internal.websocket import io.mockk.clearAllMocks +import io.mockk.clearMocks import io.mockk.every import io.mockk.mockk -import io.mockk.mockkStatic import io.mockk.verify import io.reactivex.rxjava3.observers.TestObserver import okio.ByteString.Companion.toByteString @@ -12,12 +12,11 @@ import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test import org.signal.libsignal.internal.CompletableFuture -import org.signal.libsignal.net.ChatListener -import org.signal.libsignal.net.ChatService -import org.signal.libsignal.net.ChatService.DebugInfo +import org.signal.libsignal.net.ChatConnection +import org.signal.libsignal.net.ChatConnectionListener import org.signal.libsignal.net.ChatServiceException -import org.signal.libsignal.net.IpType import org.signal.libsignal.net.Network +import org.signal.libsignal.net.UnauthenticatedChatConnection import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import java.util.concurrent.CountDownLatch @@ -25,51 +24,76 @@ import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException -import org.signal.libsignal.net.ChatService.Response as LibSignalResponse -import org.signal.libsignal.net.ChatService.ResponseAndDebugInfo as LibSignalDebugResponse class LibSignalChatConnectionTest { private val executor: ExecutorService = Executors.newSingleThreadExecutor() private val healthMonitor = mockk() - private val chatService = mockk() private val network = mockk() private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor) - private var chatListener: ChatListener? = null + private val chatConnection = mockk() + private var chatListener: ChatConnectionListener? = null + + // Used by default-success mocks for ChatConnection behavior. + private var connectLatch: CountDownLatch? = null + private var disconnectLatch: CountDownLatch? = null + private var sendLatch: CountDownLatch? = null + + private fun setupConnectedConnection() { + connectLatch = CountDownLatch(1) + connection.connect() + connectLatch!!.await(100, TimeUnit.MILLISECONDS) + } @Before fun before() { clearAllMocks() - mockkStatic(Network::createChatService) every { healthMonitor.onMessageError(any(), any()) } every { healthMonitor.onKeepAliveResponse(any(), any()) } - every { network.createChatService(any(), any(), any()) } answers { - // When mocking static methods in mockk, the mock target is included as the first - // argument in the answers block. This results in the thirdArgument() convenience method - // being off-by-one. Since we are interested in the last argument to createChatService, we need - // to manually fetch it from the args array and cast it ourselves. - chatListener = args[3] as ChatListener? - chatService - } - } - @Test - fun orderOfStatesOnSuccessfulConnect() { - val latch = CountDownLatch(1) - - every { chatService.connect() } answers { + // NB: We provide default success behavior mocks here to cut down on boilerplate later, but it is + // expected that some tests will override some of these to test failures. + // + // We provide a null credentials provider when creating `connection`, so LibSignalChatConnection + // should always call connectUnauthChat() + // TODO: Maybe also test Auth? The old one didn't. + every { network.connectUnauthChat(any()) } answers { + chatListener = firstArg() delay { - it.complete(DEBUG_INFO) - latch.countDown() + it.complete(chatConnection) + connectLatch?.countDown() } } + every { chatConnection.disconnect() } answers { + delay { + it.complete(null) + disconnectLatch?.countDown() + } + } + + every { chatConnection.send(any()) } answers { + delay { + it.complete(RESPONSE_SUCCESS) + sendLatch?.countDown() + } + } + + every { chatConnection.start() } returns Unit + } + + // Test that the LibSignalChatConnection transitions through DISCONNECTED -> CONNECTING -> CONNECTED + // if the underlying ChatConnection future completes successfully. + @Test + fun orderOfStatesOnSuccessfulConnect() { + connectLatch = CountDownLatch(1) + val observer = TestObserver() connection.state.subscribe(observer) connection.connect() - latch.await(100, TimeUnit.MILLISECONDS) + connectLatch!!.await(100, TimeUnit.MILLISECONDS) observer.assertNotComplete() observer.assertValues( @@ -79,14 +103,18 @@ class LibSignalChatConnectionTest { ) } + // Test that the LibSignalChatConnection transitions to FAILED if the + // underlying ChatConnection future completes exceptionally. @Test fun orderOfStatesOnConnectionFailure() { val connectionException = RuntimeException("connect failed") val latch = CountDownLatch(1) - every { chatService.connect() } answers { + every { network.connectUnauthChat(any()) } answers { + chatListener = firstArg() delay { it.completeExceptionally(connectionException) + latch.countDown() } } @@ -105,32 +133,21 @@ class LibSignalChatConnectionTest { ) } + // Test connect followed by disconnect, checking the state transitions. @Test fun orderOfStatesOnConnectAndDisconnect() { - val connectLatch = CountDownLatch(1) - val disconnectLatch = CountDownLatch(1) - - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - connectLatch.countDown() - } - } - every { chatService.disconnect() } answers { - delay { - it.complete(null) - disconnectLatch.countDown() - } - } + connectLatch = CountDownLatch(1) + disconnectLatch = CountDownLatch(1) val observer = TestObserver() connection.state.subscribe(observer) connection.connect() - connectLatch.await(100, TimeUnit.MILLISECONDS) + connectLatch!!.await(100, TimeUnit.MILLISECONDS) + connection.disconnect() - disconnectLatch.await(100, TimeUnit.MILLISECONDS) + disconnectLatch!!.await(100, TimeUnit.MILLISECONDS) observer.assertNotComplete() observer.assertValues( @@ -142,30 +159,21 @@ class LibSignalChatConnectionTest { ) } + // Test that a disconnect failure transitions from CONNECTED -> DISCONNECTING -> DISCONNECTED anyway, + // since we don't have a specific "DISCONNECT_FAILED" state. @Test fun orderOfStatesOnDisconnectFailure() { val disconnectException = RuntimeException("disconnect failed") - - val connectLatch = CountDownLatch(1) val disconnectLatch = CountDownLatch(1) - every { chatService.disconnect() } answers { + every { chatConnection.disconnect() } answers { delay { it.completeExceptionally(disconnectException) disconnectLatch.countDown() } } - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - connectLatch.countDown() - } - } - - connection.connect() - - connectLatch.await(100, TimeUnit.MILLISECONDS) + setupConnectedConnection() val observer = TestObserver() connection.state.subscribe(observer) @@ -176,34 +184,23 @@ class LibSignalChatConnectionTest { observer.assertNotComplete() observer.assertValues( + // The subscriber is created after we've already connected, so the first state it sees is CONNECTED: WebSocketConnectionState.CONNECTED, WebSocketConnectionState.DISCONNECTING, WebSocketConnectionState.DISCONNECTED ) } + // Test a successful keepAlive, i.e. we get a 200 OK in response to the keepAlive request, + // which triggers healthMonitor.onKeepAliveResponse(...) and not onMessageError. @Test fun keepAliveSuccess() { - val latch = CountDownLatch(1) + setupConnectedConnection() - every { chatService.sendAndDebug(any()) } answers { - delay { - it.complete(make_debug_response(RESPONSE_SUCCESS)) - latch.countDown() - } - } - - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - } - } - - connection.connect() + sendLatch = CountDownLatch(1) connection.sendKeepAlive() - - latch.await(100, TimeUnit.MILLISECONDS) + sendLatch!!.await(100, TimeUnit.MILLISECONDS) verify(exactly = 1) { healthMonitor.onKeepAliveResponse(any(), false) @@ -213,27 +210,25 @@ class LibSignalChatConnectionTest { } } + // Test keepAlive failures: we get 4xx or 5xx, which triggers healthMonitor.onMessageError(...) but not onKeepAliveResponse. @Test fun keepAliveFailure() { for (response in listOf(RESPONSE_ERROR, RESPONSE_SERVER_ERROR)) { - val latch = CountDownLatch(1) + clearMocks(healthMonitor) - every { chatService.sendAndDebug(any()) } answers { + every { chatConnection.send(any()) } answers { delay { - it.complete(make_debug_response(response)) + it.complete(response) + sendLatch?.countDown() } } - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - } - } + setupConnectedConnection() - connection.connect() + sendLatch = CountDownLatch(1) connection.sendKeepAlive() - latch.await(100, TimeUnit.MILLISECONDS) + sendLatch!!.await(100, TimeUnit.MILLISECONDS) verify(exactly = 1) { healthMonitor.onMessageError(response.status, false) @@ -244,31 +239,22 @@ class LibSignalChatConnectionTest { } } + // Test keepAlive that fails at the transport layer (send() throws), + // which transitions from CONNECTED -> DISCONNECTED. @Test fun keepAliveConnectionFailure() { val connectionFailure = RuntimeException("Sending keep-alive failed") - val connectLatch = CountDownLatch(1) val keepAliveFailureLatch = CountDownLatch(1) - every { - chatService.sendAndDebug(any()) - } answers { + every { chatConnection.send(any()) } answers { delay { it.completeExceptionally(connectionFailure) keepAliveFailureLatch.countDown() } } - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - connectLatch.countDown() - } - } - - connection.connect() - connectLatch.await(100, TimeUnit.MILLISECONDS) + setupConnectedConnection() val observer = TestObserver() connection.state.subscribe(observer) @@ -290,58 +276,17 @@ class LibSignalChatConnectionTest { } } + // Test that an incoming "connection interrupted" event from ChatConnection sets our state to DISCONNECTED. @Test fun connectionInterruptedTest() { val disconnectReason = ChatServiceException("simulated interrupt") - val connectLatch = CountDownLatch(1) - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - connectLatch.countDown() - } - } - - connection.connect() - connectLatch.await(100, TimeUnit.MILLISECONDS) + setupConnectedConnection() val observer = TestObserver() connection.state.subscribe(observer) - chatListener!!.onConnectionInterrupted(chatService, disconnectReason) - - observer.assertNotComplete() - observer.assertValues( - // We start in the connected state - WebSocketConnectionState.CONNECTED, - // Disconnects as a result of the connection interrupted event - WebSocketConnectionState.DISCONNECTED - ) - verify(exactly = 0) { - healthMonitor.onKeepAliveResponse(any(), any()) - healthMonitor.onMessageError(any(), any()) - } - } - - @Test - fun connectionInterrupted() { - val disconnectReason = ChatServiceException("simulated interrupt") - val connectLatch = CountDownLatch(1) - - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - connectLatch.countDown() - } - } - - connection.connect() - connectLatch.await(100, TimeUnit.MILLISECONDS) - - val observer = TestObserver() - connection.state.subscribe(observer) - - chatListener!!.onConnectionInterrupted(chatService, disconnectReason) + chatListener!!.onConnectionInterrupted(chatConnection, disconnectReason) observer.assertNotComplete() observer.assertValues( @@ -356,36 +301,32 @@ class LibSignalChatConnectionTest { } } + // Test reading incoming requests from the queue. + // We'll simulate onIncomingMessage() from the ChatConnectionListener, then read them from the LibSignalChatConnection. @Test fun incomingRequests() { - val connectLatch = CountDownLatch(1) - val asyncMessageReadLatch = CountDownLatch(1) - - every { chatService.connect() } answers { - delay { - it.complete(DEBUG_INFO) - connectLatch.countDown() - } - } - - connection.connect() - connectLatch.await(100, TimeUnit.MILLISECONDS) + setupConnectedConnection() val observer = TestObserver() connection.state.subscribe(observer) + // Confirm that readRequest times out if there's no message. var timedOut = false try { connection.readRequest(10) } catch (e: TimeoutException) { timedOut = true } - assert(timedOut) + assertTrue(timedOut) + // We'll now simulate incoming messages val envelopeA = "msgA".toByteArray() val envelopeB = "msgB".toByteArray() val envelopeC = "msgC".toByteArray() + val asyncMessageReadLatch = CountDownLatch(1) + + // Helper to check that the WebSocketRequestMessage for an envelope is as expected fun assertRequestWithEnvelope(request: WebSocketRequestMessage, envelope: ByteArray) { assertEquals("PUT", request.verb) assertEquals("/api/v1/message", request.path) @@ -399,6 +340,7 @@ class LibSignalChatConnectionTest { ) } + // Helper to check that a queue-empty request is as expected fun assertQueueEmptyRequest(request: WebSocketRequestMessage) { assertEquals("PUT", request.verb) assertEquals("/api/v1/queue/empty", request.path) @@ -411,20 +353,23 @@ class LibSignalChatConnectionTest { ) } + // Read request asynchronously to simulate concurrency executor.submit { - assertRequestWithEnvelope(connection.readRequest(10), envelopeA) + val request = connection.readRequest(200) + assertRequestWithEnvelope(request, envelopeA) asyncMessageReadLatch.countDown() } - chatListener!!.onIncomingMessage(chatService, envelopeA, 0, null) + + chatListener!!.onIncomingMessage(chatConnection, envelopeA, 0, null) asyncMessageReadLatch.await(100, TimeUnit.MILLISECONDS) - chatListener!!.onIncomingMessage(chatService, envelopeB, 0, null) + chatListener!!.onIncomingMessage(chatConnection, envelopeB, 0, null) assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeB) - chatListener!!.onQueueEmpty(chatService) + chatListener!!.onQueueEmpty(chatConnection) assertQueueEmptyRequest(connection.readRequestIfAvailable().get()) - chatListener!!.onIncomingMessage(chatService, envelopeC, 0, null) + chatListener!!.onIncomingMessage(chatConnection, envelopeC, 0, null) assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeC) assertTrue(connection.readRequestIfAvailable().isEmpty) @@ -439,13 +384,9 @@ class LibSignalChatConnectionTest { } companion object { - private val DEBUG_INFO: DebugInfo = DebugInfo(IpType.UNKNOWN, 100, "") - private val RESPONSE_SUCCESS = LibSignalResponse(200, "", emptyMap(), byteArrayOf()) - private val RESPONSE_ERROR = LibSignalResponse(400, "", emptyMap(), byteArrayOf()) - private val RESPONSE_SERVER_ERROR = LibSignalResponse(500, "", emptyMap(), byteArrayOf()) - - private fun make_debug_response(response: LibSignalResponse): LibSignalDebugResponse { - return LibSignalDebugResponse(response, DEBUG_INFO) - } + // For verifying success / error scenarios in keepAlive tests, etc. + private val RESPONSE_SUCCESS = ChatConnection.Response(200, "", emptyMap(), byteArrayOf()) + private val RESPONSE_ERROR = ChatConnection.Response(400, "", emptyMap(), byteArrayOf()) + private val RESPONSE_SERVER_ERROR = ChatConnection.Response(500, "", emptyMap(), byteArrayOf()) } }