diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index a3fee076ad..c6fc372da2 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -29,6 +29,7 @@ import java.io.IOException import java.time.Instant import java.util.Optional import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException @@ -75,8 +76,16 @@ class LibSignalChatConnection( private val nextIncomingMessageInternalPseudoId = AtomicLong(1) val ackSenderForInternalPseudoId = ConcurrentHashMap() + // CHAT_SERVICE_LOCK: Protects state, stateChangedOrMessageReceivedCondition, chatConnection, and + // chatConnectionFuture + // stateChangedOrMessageReceivedCondition: derived from CHAT_SERVICE_LOCK, used by readRequest(), + // exists to emulate idiosyncratic behavior of OkHttpWebSocketConnection for readRequest() + // chatConnection: Set only when state == CONNECTED + // chatConnectionFuture: Set only when state == CONNECTING private val CHAT_SERVICE_LOCK = ReentrantLock() + private val stateChangedOrMessageReceivedCondition = CHAT_SERVICE_LOCK.newCondition() private var chatConnection: ChatConnection? = null + private var chatConnectionFuture: CompletableFuture? = null companion object { const val SERVICE_ENVELOPE_REQUEST_VERB = "PUT" @@ -126,10 +135,14 @@ class LibSignalChatConnection( val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) - val cleanupMonitor = state.subscribe { nextState -> + val stateMonitor = state.subscribe { nextState -> if (nextState == WebSocketConnectionState.DISCONNECTED) { cleanup() } + + CHAT_SERVICE_LOCK.withLock { + stateChangedOrMessageReceivedCondition.signalAll() + } } private fun cleanup() { @@ -155,13 +168,15 @@ class LibSignalChatConnection( return state } Log.i(TAG, "$name Connecting...") - val chatConnectionFuture: CompletableFuture = if (credentialsProvider == null) { + chatConnectionFuture = if (credentialsProvider == null) { network.connectUnauthChat(listener) } else { network.connectAuthChat(credentialsProvider.username, credentialsProvider.password, receiveStories, listener) } state.onNext(WebSocketConnectionState.CONNECTING) - chatConnectionFuture.whenComplete( + // We are now in the CONNECTING state, so chatConnectionFuture should be set, and there is no + // nullability concern here. + chatConnectionFuture!!.whenComplete( onSuccess = { connection -> CHAT_SERVICE_LOCK.withLock { if (state.value == WebSocketConnectionState.CONNECTING) { @@ -218,13 +233,14 @@ class LibSignalChatConnection( return } - // This avoids a crash when we get a connection lost event during a connection attempt and try - // to cancel a connection that has not yet been fully established. - // TODO [andrew]: Figure out if this is the right long term behavior. + // OkHttpWebSocketConnection will terminate a connection if disconnect() is called while + // the connection itself is still CONNECTING, so we carry forward that behavior here. if (state.value == WebSocketConnectionState.CONNECTING) { - // The right way to do this is to cancel the CompletableFuture returned by connectChat() + // The right way to do this is to cancel the CompletableFuture returned by connectChat(). + // This will terminate forward progress on the connection attempt, and mostly closely match + // what OkHttpWebSocketConnection does. // Unfortunately, libsignal's CompletableFuture does not yet support cancellation. - // Instead, we set a flag to disconnect() as soon as the connection completes. + // So, instead, we set a flag to disconnect() as soon as the connection completes. // TODO [andrew]: Add cancellation support to CompletableFuture and use it here state.onNext(WebSocketConnectionState.DISCONNECTING) return @@ -289,11 +305,33 @@ class LibSignalChatConnection( override fun sendKeepAlive() { CHAT_SERVICE_LOCK.withLock { - // This is a stronger check than isDead, to handle the case where chatConnection may be null - // because we are still connecting. - // TODO [andrew]: Decide if this is the right behavior long term, or if we want to queue these - // like we plan to queue other requests long term. - if (state.value != WebSocketConnectionState.CONNECTED) { + if (isDead()) { + // This matches the behavior of OkHttpWebSocketConnection, where if a keep alive is sent + // while we are not connected, we simply drop the keep alive. + return + } + + if (state.value == WebSocketConnectionState.CONNECTING) { + // Handle the special case where we are connecting, so we cannot (yet) send the keep-alive. + // OkHttpWebSocketConnection buffers the keep alive request, and sends it when the connection + // completes. + // We just checked that we are in the CONNECTING state, and we hold the CHAT_SERVICE_LOCK, so + // our state cannot change, thus there is no nullability concern with chatConnectionFuture. + Log.i(TAG, "$name Buffering keep alive to send after connection establishment") + chatConnectionFuture!!.whenComplete( + onSuccess = { + Log.i(TAG, "$name Sending buffered keep alive") + // sendKeepAlive() will internally grab the CHAT_SERVICE_LOCK and check to ensure we are + // still in the CONNECTED state when this callback runs, so we do not need to worry about + // any state here. + sendKeepAlive() + }, + onFailure = { + // OkHttpWebSocketConnection did not report a keep alive failure to the healthMonitor + // when a buffered keep alive failed to send because the underlying connection + // establishment failed, so neither do we. + } + ) return } @@ -332,22 +370,76 @@ class LibSignalChatConnection( return Optional.ofNullable(incomingMessage) } + /** + * Blocks until a request is received from the underlying ChatConnection. + * + * This method’s behavior is critical for message retrieval and must adhere to the following: + * + * - Blocks until a request is available. + * - If no message is received within the specified [timeoutMillis], a [TimeoutException] is thrown. + * - If the ChatConnection becomes disconnected while waiting, an [IOException] is thrown immediately. + * - If invoked when the ChatConnection is dead (i.e. disconnected or failed), an [IOException] is thrown. + * - If the ChatConnection is still in the process of connecting, the method will block until the connection + * is established and a message is received. The time spent waiting for the connection is counted towards + * the [timeoutMillis]. Should the connection attempt eventually fail, an [IOException] is thrown promptly. + * + * **Note:** This method is used by the MessageRetrievalThread to receive updates about the connection state + * from other threads. Any delay in throwing exceptions could block this thread, resulting in prolonged holding + * of the Foreground Service and wake lock, which may lead to adverse behavior by the operating system. + * + * @param timeoutMillis the maximum time in milliseconds to wait for a request. + * @return the received [WebSocketRequestMessage]. + * @throws TimeoutException if the timeout elapses without receiving a message. + * @throws IOException if the ChatConnection becomes disconnected, is dead, or if the connection attempt fails. + */ override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage { - return readRequestInternal(timeoutMillis, timeoutMillis) - } - - private fun readRequestInternal(timeoutMillis: Long, originalTimeoutMillis: Long): WebSocketRequestMessage { - if (timeoutMillis < 0) { - throw TimeoutException("No message available after $originalTimeoutMillis ms") + if (timeoutMillis <= 0) { + // OkHttpWebSocketConnection throws a TimeoutException in this case, so we do too. + throw TimeoutException("Invalid timeoutMillis") } val startTime = System.currentTimeMillis() - try { - return incomingRequestQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS) ?: throw TimeoutException("No message available after $originalTimeoutMillis ms") - } catch (e: InterruptedException) { - val elapsedTimeMillis = System.currentTimeMillis() - startTime - val timeoutRemainingMillis = timeoutMillis - elapsedTimeMillis - return readRequestInternal(timeoutRemainingMillis, originalTimeoutMillis) + + CHAT_SERVICE_LOCK.withLock { + if (isDead()) { + // Matches behavior of OkHttpWebSocketConnection + throw IOException("Connection closed!") + } + + var remainingTimeoutMillis = timeoutMillis + + fun couldGetRequest(): Boolean { + return state.value == WebSocketConnectionState.CONNECTED || state.value == WebSocketConnectionState.CONNECTING + } + + while (couldGetRequest() && incomingRequestQueue.isEmpty()) { + if (remainingTimeoutMillis <= 0) { + throw TimeoutException("Timeout exceeded after $timeoutMillis ms") + } + + try { + // This condition variable is created from CHAT_SERVICE_LOCK, and thus releases CHAT_SERVICE_LOCK + // while we await the condition variable. + stateChangedOrMessageReceivedCondition.await(remainingTimeoutMillis, TimeUnit.MILLISECONDS) + } catch (_: InterruptedException) { } + val elapsedTimeMillis = System.currentTimeMillis() - startTime + remainingTimeoutMillis = timeoutMillis - elapsedTimeMillis + } + + if (!incomingRequestQueue.isEmpty()) { + return incomingRequestQueue.poll() + } else if (!couldGetRequest()) { + throw IOException("Connection closed!") + } else { + // This happens if we somehow break out of the loop but incomingRequestQueue is empty + // and we were still in a state where we could get a request. + // This *could* theoretically happen if two different threads call readRequest at the same time, + // this thread is the one that loses the race to take the request off the queue. + // (NB: I don't think this is a practical issue, because readRequest() should only be called from + // the MessageRetrievalThread, but OkHttpWebSocketConnection treated this as a TimeoutException, so + // this class also dutifully treats it as a TimeoutException.) + throw TimeoutException("Incoming request queue was empty!") + } } } @@ -366,6 +458,8 @@ class LibSignalChatConnection( private val listener = LibSignalChatListener() private inner class LibSignalChatListener : ChatConnectionListener { + private val executor = Executors.newSingleThreadExecutor() + override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) { // NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is // already in the ackSender map, if it exists. @@ -381,6 +475,12 @@ class LibSignalChatConnection( ackSenderForInternalPseudoId[internalPseudoId] = sendAck } incomingRequestQueue.put(incomingWebSocketRequest) + // Try to not block the ChatConnectionListener callback context if we can help it. + executor.submit { + CHAT_SERVICE_LOCK.withLock { + stateChangedOrMessageReceivedCondition.signalAll() + } + } } override fun onConnectionInterrupted(chat: ChatConnection, disconnectReason: ChatServiceException?) { @@ -402,6 +502,12 @@ class LibSignalChatConnection( id = internalPseudoId ) incomingRequestQueue.put(queueEmptyRequest) + // Try to not block the ChatConnectionListener callback context if we can help it. + executor.submit { + CHAT_SERVICE_LOCK.withLock { + stateChangedOrMessageReceivedCondition.signalAll() + } + } } } } diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index 9b559bace7..a6c7c6cee1 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -8,6 +8,7 @@ import io.mockk.verify import io.reactivex.rxjava3.observers.TestObserver import okio.ByteString.Companion.toByteString import org.junit.Assert.assertEquals +import org.junit.Assert.assertThrows import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test @@ -19,6 +20,7 @@ import org.signal.libsignal.net.Network import org.signal.libsignal.net.UnauthenticatedChatConnection import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState +import java.io.IOException import java.util.concurrent.CountDownLatch import java.util.concurrent.ExecutorService import java.util.concurrent.Executors @@ -307,6 +309,34 @@ class LibSignalChatConnectionTest { } } + // If readRequest() does not throw when the underlying connection disconnects, this + // causes the app to get stuck in a "fetching new messages" state. + @Test + fun regressionTestReadRequestThrowsOnDisconnect() { + setupConnectedConnection() + + executor.submit { + Thread.sleep(100) + chatConnection.disconnect() + } + + assertThrows(IOException::class.java) { + connection.readRequest(1000) + } + } + + @Test(timeout = 20) + fun readRequestDoesTimeOut() { + setupConnectedConnection() + + val observer = TestObserver() + connection.state.subscribe(observer) + + assertThrows(TimeoutException::class.java) { + connection.readRequest(10) + } + } + // Test reading incoming requests from the queue. // We'll simulate onIncomingMessage() from the ChatConnectionListener, then read them from the LibSignalChatConnection. @Test @@ -316,15 +346,6 @@ class LibSignalChatConnectionTest { val observer = TestObserver() connection.state.subscribe(observer) - // Confirm that readRequest times out if there's no message. - var timedOut = false - try { - connection.readRequest(10) - } catch (e: TimeoutException) { - timedOut = true - } - assertTrue(timedOut) - // We'll now simulate incoming messages val envelopeA = "msgA".toByteArray() val envelopeB = "msgB".toByteArray()