From 080b79c893ae59471a032083d76bdfa9cfcaf4e3 Mon Sep 17 00:00:00 2001 From: andrew-signal Date: Thu, 5 Dec 2024 09:27:31 -0500 Subject: [PATCH] Use LibSignalChatConnection for Authenticated Socket based on Remote Config --- .../ApplicationDependencyProvider.java | 23 ++- .../api/websocket/HealthMonitor.java | 10 -- .../api/websocket/HealthMonitor.kt | 10 ++ .../websocket/LibSignalChatConnection.kt | 166 +++++++++++++++--- .../internal/websocket/WebSocketConnection.kt | 2 +- .../websocket/LibSignalChatConnectionTest.kt | 111 ++++++++++++ 6 files changed, 278 insertions(+), 44 deletions(-) delete mode 100644 libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.java create mode 100644 libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.kt diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index cabd791375..cfeaf80661 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -408,12 +408,23 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { return new WebSocketFactory() { @Override public WebSocketConnection createWebSocket() { - return new OkHttpWebSocketConnection("normal", - signalServiceConfigurationSupplier.get(), - Optional.of(new DynamicCredentialsProvider()), - BuildConfig.SIGNAL_AGENT, - healthMonitor, - Stories.isFeatureEnabled()); + if (RemoteConfig.libSignalWebSocketEnabled()) { + Network network = libSignalNetworkSupplier.get(); + return new LibSignalChatConnection( + "libsignal-auth", + network, + new DynamicCredentialsProvider(), + Stories.isFeatureEnabled(), + healthMonitor + ); + } else { + return new OkHttpWebSocketConnection("normal", + signalServiceConfigurationSupplier.get(), + Optional.of(new DynamicCredentialsProvider()), + BuildConfig.SIGNAL_AGENT, + healthMonitor, + Stories.isFeatureEnabled()); + } } @Override diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.java deleted file mode 100644 index 692780bb27..0000000000 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.java +++ /dev/null @@ -1,10 +0,0 @@ -package org.whispersystems.signalservice.api.websocket; - -/** - * Callbacks to provide WebSocket health information to a monitor. - */ -public interface HealthMonitor { - void onKeepAliveResponse(long sentTimestamp, boolean isIdentifiedWebSocket); - - void onMessageError(int status, boolean isIdentifiedWebSocket); -} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.kt new file mode 100644 index 0000000000..2135d4cb83 --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/HealthMonitor.kt @@ -0,0 +1,10 @@ +package org.whispersystems.signalservice.api.websocket + +/** + * Callbacks to provide WebSocket health information to a monitor. + */ +interface HealthMonitor { + fun onKeepAliveResponse(sentTimestamp: Long, isIdentifiedWebSocket: Boolean) + + fun onMessageError(status: Int, isIdentifiedWebSocket: Boolean) +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index 181b7c8d6e..588f4fe9fd 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -10,11 +10,14 @@ import io.reactivex.rxjava3.core.Single import io.reactivex.rxjava3.schedulers.Schedulers import io.reactivex.rxjava3.subjects.BehaviorSubject import io.reactivex.rxjava3.subjects.SingleSubject +import okio.ByteString +import okio.ByteString.Companion.toByteString import org.signal.core.util.logging.Log import org.signal.libsignal.net.AuthenticatedChatService import org.signal.libsignal.net.ChatListener import org.signal.libsignal.net.ChatService import org.signal.libsignal.net.ChatServiceException +import org.signal.libsignal.net.DeviceDeregisteredException import org.signal.libsignal.net.Network import org.signal.libsignal.net.UnauthenticatedChatService import org.whispersystems.signalservice.api.util.CredentialsProvider @@ -24,6 +27,11 @@ import org.whispersystems.signalservice.internal.util.whenComplete import java.io.IOException import java.time.Instant import java.util.Optional +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.withLock import kotlin.time.Duration.Companion.seconds @@ -50,11 +58,33 @@ class LibSignalChatConnection( private val receiveStories: Boolean, private val healthMonitor: HealthMonitor ) : WebSocketConnection { + private val incomingRequestQueue = LinkedBlockingQueue() + + // One of the more nasty parts of this is that libsignal-net does not expose, nor does it ever + // intend to expose, the ID of the incoming "request" to the app layer. Instead, the app layer + // is given a callback for each message it should call when it wants to ack that message. + // The layer above this, SignalWebSocket.java, is written to handle HTTP Requests and Responses + // that have responses embedded within them. + // The goal of this stage of the project is to try and change as little as possible to isolate + // any bugs with underlying libsignal-net layer. + // So, we lie. + // We assign our own "pseudo IDs" for each incoming request in this layer, provide that ID + // up the stack to the SignalWebSocket, and then we store it. Eventually, SignalWebSocket will + // tell us to send a response for that ID, and then we use the pseudo ID as a handle to find + // the callback given to us earlier by libsignal-net, and we call that callback. + private val nextIncomingMessageInternalPseudoId = AtomicLong(1) + val ackSenderForInternalPseudoId = ConcurrentHashMap() private val CHAT_SERVICE_LOCK = ReentrantLock() private var chatService: ChatService? = null companion object { + const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT" + const val SERVICE_ENVELOPE_REQUEST_PATH = "/api/v1/message" + const val SOCKET_EMPTY_REQUEST_VERB = "PUT" + const val SOCKET_EMPTY_REQUEST_PATH = "/api/v1/queue/empty" + const val SIGNAL_SERVICE_ENVELOPE_TIMESTAMP_HEADER_KEY = "X-Signal-Timestamp" + private val TAG = Log.tag(LibSignalChatConnection::class.java) private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds @@ -96,6 +126,22 @@ class LibSignalChatConnection( val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) + val cleanupMonitor = state.subscribe { nextState -> + if (nextState == WebSocketConnectionState.DISCONNECTED) { + cleanup() + } + } + + private fun cleanup() { + Log.i(TAG, "$name [cleanup]") + incomingRequestQueue.clear() + // There's a race condition here where someone has a request with an ack outstanding + // when we clear the ackSender table, but it's benign because we handle the case where + // there is no ackSender for a pseudoId gracefully in sendResponse. + ackSenderForInternalPseudoId.clear() + // There's no sense in resetting nextIncomingMessageInternalPseudoId. + } + override fun connect(): Observable { CHAT_SERVICE_LOCK.withLock { if (chatService != null) { @@ -112,9 +158,16 @@ class LibSignalChatConnection( state.onNext(WebSocketConnectionState.CONNECTED) }, onFailure = { throwable -> - // TODO[libsignal-net]: Report AUTHENTICATION_FAILED for 401 and 403 errors - Log.w(TAG, "$name Connect failed", throwable) - state.onNext(WebSocketConnectionState.FAILED) + Log.w(TAG, "$name [connect] Failure:", throwable) + // Internally, libsignal-net will throw this DeviceDeregisteredException when the HTTP CONNECT + // request returns HTTP 403. + // The chat service currently does not return HTTP 401 on /v1/websocket. + // Thus, this currently matches the implementation in OkHttpWebSocketConnection. + if (throwable is DeviceDeregisteredException) { + state.onNext(WebSocketConnectionState.AUTHENTICATION_FAILED) + } else { + state.onNext(WebSocketConnectionState.FAILED) + } } ) } @@ -122,7 +175,11 @@ class LibSignalChatConnection( } } - override fun isDead(): Boolean = false + override fun isDead(): Boolean { + CHAT_SERVICE_LOCK.withLock { + return chatService == null + } + } override fun disconnect() { CHAT_SERVICE_LOCK.withLock { @@ -150,16 +207,20 @@ class LibSignalChatConnection( override fun sendRequest(request: WebSocketRequestMessage): Single { CHAT_SERVICE_LOCK.withLock { if (chatService == null) { - return Single.error(IOException("[$name] is closed!")) + return Single.error(IOException("$name is closed!")) } val single = SingleSubject.create() val internalRequest = request.toLibSignalRequest() chatService!!.send(internalRequest) .whenComplete( onSuccess = { response -> - when (response!!.status) { + Log.d(TAG, "$name [sendRequest] Success: ${response!!.status}") + when (response.status) { in 400..599 -> { - healthMonitor.onMessageError(response.status, false) + healthMonitor.onMessageError( + status = response.status, + isIdentifiedWebSocket = chatService is AuthenticatedChatService + ) } } // Here success means "we received the response" even if it is reporting an error. @@ -167,7 +228,7 @@ class LibSignalChatConnection( single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService))) }, onFailure = { throwable -> - Log.w(TAG, "$name sendRequest failed", throwable) + Log.w(TAG, "$name [sendRequest] Failure:", throwable) single.onError(throwable) } ) @@ -185,12 +246,12 @@ class LibSignalChatConnection( chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST) .whenComplete( onSuccess = { debugResponse -> - Log.d(TAG, "$name Keep alive - success") + Log.d(TAG, "$name [sendKeepAlive] Success") when (debugResponse!!.response.status) { in 200..299 -> { healthMonitor.onKeepAliveResponse( - Instant.now().toEpochMilli(), // ignored. can be any value - false + sentTimestamp = Instant.now().toEpochMilli(), // ignored. can be any value + isIdentifiedWebSocket = chatService is AuthenticatedChatService ) } @@ -199,12 +260,12 @@ class LibSignalChatConnection( } else -> { - Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}") + Log.w(TAG, "$name [sendKeepAlive] Unsupported keep alive response status: ${debugResponse.response.status}") } } }, onFailure = { throwable -> - Log.w(TAG, "$name Keep alive - failed", throwable) + Log.w(TAG, "$name [sendKeepAlive] Failure:", throwable) state.onNext(WebSocketConnectionState.DISCONNECTED) } ) @@ -212,28 +273,79 @@ class LibSignalChatConnection( } override fun readRequestIfAvailable(): Optional { - throw NotImplementedError() + val incomingMessage = incomingRequestQueue.poll() + return Optional.ofNullable(incomingMessage) } override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage { - throw NotImplementedError() + return readRequestInternal(timeoutMillis, timeoutMillis) } - override fun sendResponse(response: WebSocketResponseMessage?) { - throw NotImplementedError() - } - - private val listener = object : ChatListener { - override fun onIncomingMessage(chat: ChatService?, envelope: ByteArray?, serverDeliveryTimestamp: Long, sendAck: ChatListener.ServerMessageAck?) { - throw NotImplementedError() + private fun readRequestInternal(timeoutMillis: Long, originalTimeoutMillis: Long): WebSocketRequestMessage { + if (timeoutMillis < 0) { + throw TimeoutException("No message available after $originalTimeoutMillis ms") } - override fun onConnectionInterrupted(chat: ChatService?, disconnectReason: ChatServiceException?) { - CHAT_SERVICE_LOCK.withLock { - Log.i(TAG, "connection interrupted", disconnectReason) - state.onNext(WebSocketConnectionState.DISCONNECTED) - chatService = null + val startTime = System.currentTimeMillis() + try { + return incomingRequestQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS) ?: throw TimeoutException("No message available after $originalTimeoutMillis ms") + } catch (e: InterruptedException) { + val elapsedTimeMillis = System.currentTimeMillis() - startTime + val timeoutRemainingMillis = timeoutMillis - elapsedTimeMillis + return readRequestInternal(timeoutRemainingMillis, originalTimeoutMillis) + } + } + + override fun sendResponse(response: WebSocketResponseMessage) { + if (response.status == 200 && response.message.equals("OK")) { + ackSenderForInternalPseudoId[response.id]?.send() ?: Log.w(TAG, "$name [sendResponse] Silently dropped response without available ackSend {id: ${response.id}}") + ackSenderForInternalPseudoId.remove(response.id) + Log.d(TAG, "$name [sendResponse] sent ack [${response.id}]") + } else { + // libsignal-net only supports sending {200: OK} responses + Log.w(TAG, "$name [sendResponse] Silently dropped unsupported response {status: ${response.status}, id: ${response.id}}") + ackSenderForInternalPseudoId.remove(response.id) + } + } + + private val listener = LibSignalChatListener() + + private inner class LibSignalChatListener : ChatListener { + override fun onIncomingMessage(chat: ChatService, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatListener.ServerMessageAck?) { + // NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is + // already in the ackSender map, if it exists. + val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement() + val incomingWebSocketRequest = WebSocketRequestMessage( + verb = SERVICE_ENVELOPE_REQUEST_VERB, + path = SERVICE_ENVELOPE_REQUEST_PATH, + body = envelope.toByteString(), + headers = listOf("$SIGNAL_SERVICE_ENVELOPE_TIMESTAMP_HEADER_KEY: $serverDeliveryTimestamp"), + id = internalPseudoId + ) + if (sendAck != null) { + ackSenderForInternalPseudoId[internalPseudoId] = sendAck } + incomingRequestQueue.put(incomingWebSocketRequest) + } + + override fun onConnectionInterrupted(chat: ChatService, disconnectReason: ChatServiceException) { + CHAT_SERVICE_LOCK.withLock { + Log.i(TAG, "$name connection interrupted", disconnectReason) + chatService = null + state.onNext(WebSocketConnectionState.DISCONNECTED) + } + } + + override fun onQueueEmpty(chat: ChatService) { + val internalPseudoId = nextIncomingMessageInternalPseudoId.getAndIncrement() + val queueEmptyRequest = WebSocketRequestMessage( + verb = SOCKET_EMPTY_REQUEST_VERB, + path = SOCKET_EMPTY_REQUEST_PATH, + body = ByteString.EMPTY, + headers = listOf(), + id = internalPseudoId + ) + incomingRequestQueue.put(queueEmptyRequest) } } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt index d6e0e1197f..259f47862d 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt @@ -35,5 +35,5 @@ interface WebSocketConnection { fun readRequest(timeoutMillis: Long): WebSocketRequestMessage @Throws(IOException::class) - fun sendResponse(response: WebSocketResponseMessage?) + fun sendResponse(response: WebSocketResponseMessage) } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index 12f0971860..db67a8d0ab 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -6,6 +6,9 @@ import io.mockk.mockk import io.mockk.mockkStatic import io.mockk.verify import io.reactivex.rxjava3.observers.TestObserver +import okio.ByteString.Companion.toByteString +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test import org.signal.libsignal.internal.CompletableFuture @@ -21,6 +24,7 @@ import java.util.concurrent.CountDownLatch import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException import org.signal.libsignal.net.ChatService.Response as LibSignalResponse import org.signal.libsignal.net.ChatService.ResponseAndDebugInfo as LibSignalDebugResponse @@ -286,6 +290,39 @@ class LibSignalChatConnectionTest { } } + @Test + fun connectionInterruptedTest() { + val disconnectReason = ChatServiceException("simulated interrupt") + val connectLatch = CountDownLatch(1) + + every { chatService.connect() } answers { + delay { + it.complete(DEBUG_INFO) + connectLatch.countDown() + } + } + + connection.connect() + connectLatch.await(100, TimeUnit.MILLISECONDS) + + val observer = TestObserver() + connection.state.subscribe(observer) + + chatListener!!.onConnectionInterrupted(chatService, disconnectReason) + + observer.assertNotComplete() + observer.assertValues( + // We start in the connected state + WebSocketConnectionState.CONNECTED, + // Disconnects as a result of the connection interrupted event + WebSocketConnectionState.DISCONNECTED + ) + verify(exactly = 0) { + healthMonitor.onKeepAliveResponse(any(), any()) + healthMonitor.onMessageError(any(), any()) + } + } + @Test fun connectionInterrupted() { val disconnectReason = ChatServiceException("simulated interrupt") @@ -319,6 +356,80 @@ class LibSignalChatConnectionTest { } } + @Test + fun incomingRequests() { + val connectLatch = CountDownLatch(1) + val asyncMessageReadLatch = CountDownLatch(1) + + every { chatService.connect() } answers { + delay { + it.complete(DEBUG_INFO) + connectLatch.countDown() + } + } + + connection.connect() + connectLatch.await(100, TimeUnit.MILLISECONDS) + + val observer = TestObserver() + connection.state.subscribe(observer) + + var timedOut = false + try { + connection.readRequest(10) + } catch (e: TimeoutException) { + timedOut = true + } + assert(timedOut) + + val envelopeA = "msgA".toByteArray() + val envelopeB = "msgB".toByteArray() + val envelopeC = "msgC".toByteArray() + + fun assertRequestWithEnvelope(request: WebSocketRequestMessage, envelope: ByteArray) { + assertEquals("PUT", request.verb) + assertEquals("/api/v1/message", request.path) + assertEquals(envelope.toByteString(), request.body!!) + connection.sendResponse( + WebSocketResponseMessage( + request.id, + 200, + "OK" + ) + ) + } + + fun assertQueueEmptyRequest(request: WebSocketRequestMessage) { + assertEquals("PUT", request.verb) + assertEquals("/api/v1/queue/empty", request.path) + connection.sendResponse( + WebSocketResponseMessage( + request.id, + 200, + "OK" + ) + ) + } + + executor.submit { + assertRequestWithEnvelope(connection.readRequest(10), envelopeA) + asyncMessageReadLatch.countDown() + } + chatListener!!.onIncomingMessage(chatService, envelopeA, 0, null) + asyncMessageReadLatch.await(100, TimeUnit.MILLISECONDS) + + chatListener!!.onIncomingMessage(chatService, envelopeB, 0, null) + assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeB) + + chatListener!!.onQueueEmpty(chatService) + assertQueueEmptyRequest(connection.readRequestIfAvailable().get()) + + chatListener!!.onIncomingMessage(chatService, envelopeC, 0, null) + assertRequestWithEnvelope(connection.readRequestIfAvailable().get(), envelopeC) + + assertTrue(connection.readRequestIfAvailable().isEmpty) + } + private fun delay(action: ((CompletableFuture) -> Unit)): CompletableFuture { val future = CompletableFuture() executor.submit {