diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index c7588711ad..4ab7193908 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -434,7 +434,9 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { Network network = libSignalNetworkSupplier.get(); return new LibSignalChatConnection( "libsignal-unauth", - LibSignalNetworkExtensions.createChatService(network, null, Stories.isFeatureEnabled()), + network, + null, + Stories.isFeatureEnabled(), healthMonitor); } else { return new OkHttpWebSocketConnection("unidentified", diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index 14d4ce76c7..00f7f19aea 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -13,12 +13,17 @@ import io.reactivex.rxjava3.subjects.SingleSubject import org.signal.core.util.logging.Log import org.signal.libsignal.net.AuthenticatedChatService import org.signal.libsignal.net.ChatService +import org.signal.libsignal.net.Network import org.signal.libsignal.net.UnauthenticatedChatService +import org.whispersystems.signalservice.api.util.CredentialsProvider import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.internal.util.whenComplete +import java.io.IOException import java.time.Instant import java.util.Optional +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock import kotlin.time.Duration.Companion.seconds import org.signal.libsignal.net.ChatService.Request as LibSignalRequest import org.signal.libsignal.net.ChatService.Response as LibSignalResponse @@ -38,10 +43,15 @@ import org.signal.libsignal.net.ChatService.Response as LibSignalResponse */ class LibSignalChatConnection( name: String, - private val chatService: ChatService, + private val network: Network, + private val credentialsProvider: CredentialsProvider?, + private val receiveStories: Boolean, private val healthMonitor: HealthMonitor ) : WebSocketConnection { + private val CHAT_SERVICE_LOCK = ReentrantLock() + private var chatService: ChatService? = null + companion object { private val TAG = Log.tag(LibSignalChatConnection::class.java) private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds @@ -85,95 +95,118 @@ class LibSignalChatConnection( val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) override fun connect(): Observable { - Log.i(TAG, "$name Connecting...") - state.onNext(WebSocketConnectionState.CONNECTING) - chatService.connect() - .whenComplete( - onSuccess = { debugInfo -> - Log.i(TAG, "$name Connected") - Log.d(TAG, "$name $debugInfo") - state.onNext(WebSocketConnectionState.CONNECTED) - }, - onFailure = { throwable -> - // TODO: [libsignal-net] Report WebSocketConnectionState.AUTHENTICATION_FAILED for 401 and 403 errors - Log.d(TAG, "$name Connect failed", throwable) - state.onNext(WebSocketConnectionState.FAILED) - } - ) - return state + CHAT_SERVICE_LOCK.withLock { + if (chatService != null) { + return state + } + + Log.i(TAG, "$name Connecting...") + chatService = network.createChatService(credentialsProvider, receiveStories).apply { + state.onNext(WebSocketConnectionState.CONNECTING) + connect().whenComplete( + onSuccess = { debugInfo -> + Log.i(TAG, "$name Connected") + Log.d(TAG, "$name $debugInfo") + state.onNext(WebSocketConnectionState.CONNECTED) + }, + onFailure = { throwable -> + // TODO[libsignal-net]: Report AUTHENTICATION_FAILED for 401 and 403 errors + Log.w(TAG, "$name Connect failed", throwable) + state.onNext(WebSocketConnectionState.FAILED) + } + ) + } + return state + } } override fun isDead(): Boolean = false override fun disconnect() { - Log.i(TAG, "$name Disconnecting...") - state.onNext(WebSocketConnectionState.DISCONNECTING) - chatService.disconnect() - .whenComplete( - onSuccess = { - Log.i(TAG, "$name Disconnected") - state.onNext(WebSocketConnectionState.DISCONNECTED) - }, - onFailure = { throwable -> - Log.d(TAG, "$name Disconnect failed", throwable) - state.onNext(WebSocketConnectionState.DISCONNECTED) - } - ) + CHAT_SERVICE_LOCK.withLock { + if (chatService == null) { + return + } + + Log.i(TAG, "$name Disconnecting...") + state.onNext(WebSocketConnectionState.DISCONNECTING) + chatService!!.disconnect() + .whenComplete( + onSuccess = { + Log.i(TAG, "$name Disconnected") + state.onNext(WebSocketConnectionState.DISCONNECTED) + }, + onFailure = { throwable -> + Log.w(TAG, "$name Disconnect failed", throwable) + state.onNext(WebSocketConnectionState.DISCONNECTED) + } + ) + chatService = null + } } override fun sendRequest(request: WebSocketRequestMessage): Single { - val single = SingleSubject.create() - val internalRequest = request.toLibSignalRequest() - chatService.send(internalRequest) - .whenComplete( - onSuccess = { response -> - when (response!!.status) { - in 400..599 -> { - healthMonitor.onMessageError(response.status, false) + CHAT_SERVICE_LOCK.withLock { + if (chatService == null) { + return Single.error(IOException("[$name] is closed!")) + } + val single = SingleSubject.create() + val internalRequest = request.toLibSignalRequest() + chatService!!.send(internalRequest) + .whenComplete( + onSuccess = { response -> + when (response!!.status) { + in 400..599 -> { + healthMonitor.onMessageError(response.status, false) + } } + // Here success means "we received the response" even if it is reporting an error. + // This is consistent with the behavior of the OkHttpWebSocketConnection. + single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService))) + }, + onFailure = { throwable -> + Log.w(TAG, "$name sendRequest failed", throwable) + single.onError(throwable) } - // Here success means "we received the response" even if it is reporting an error. - // This is consistent with the behavior of the OkHttpWebSocketConnection. - single.onSuccess(response.toWebsocketResponse(isUnidentified = (chatService is UnauthenticatedChatService))) - }, - onFailure = { throwable -> - Log.i(TAG, "$name sendRequest failed", throwable) - single.onError(throwable) - } - ) - return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) + ) + return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) + } } override fun sendKeepAlive() { - Log.i(TAG, "$name Sending keep alive...") - chatService.sendAndDebug(KEEP_ALIVE_REQUEST) - .whenComplete( - onSuccess = { debugResponse -> - Log.i(TAG, "$name Keep alive - success") - Log.d(TAG, "$name $debugResponse") - when (debugResponse!!.response.status) { - in 200..299 -> { - healthMonitor.onKeepAliveResponse( - Instant.now().toEpochMilli(), // ignored. can be any value - false - ) - } + CHAT_SERVICE_LOCK.withLock { + if (chatService == null) { + return + } - in 400..599 -> { - healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService)) - } + Log.i(TAG, "$name Sending keep alive...") + chatService!!.sendAndDebug(KEEP_ALIVE_REQUEST) + .whenComplete( + onSuccess = { debugResponse -> + Log.d(TAG, "$name Keep alive - success") + when (debugResponse!!.response.status) { + in 200..299 -> { + healthMonitor.onKeepAliveResponse( + Instant.now().toEpochMilli(), // ignored. can be any value + false + ) + } - else -> { - Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}") + in 400..599 -> { + healthMonitor.onMessageError(debugResponse.response.status, (chatService is AuthenticatedChatService)) + } + + else -> { + Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}") + } } + }, + onFailure = { throwable -> + Log.w(TAG, "$name Keep alive - failed", throwable) + state.onNext(WebSocketConnectionState.DISCONNECTED) } - }, - onFailure = { throwable -> - Log.i(TAG, "$name Keep alive - failed") - Log.d(TAG, "$name $throwable") - state.onNext(WebSocketConnectionState.DISCONNECTED) - } - ) + ) + } } override fun readRequestIfAvailable(): Optional { diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt index ab3eb7b61a..6d24bc6558 100644 --- a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -3,6 +3,7 @@ package org.whispersystems.signalservice.internal.websocket import io.mockk.clearAllMocks import io.mockk.every import io.mockk.mockk +import io.mockk.mockkStatic import io.mockk.verify import io.reactivex.rxjava3.observers.TestObserver import org.junit.Before @@ -11,6 +12,7 @@ import org.signal.libsignal.internal.CompletableFuture import org.signal.libsignal.net.ChatService import org.signal.libsignal.net.ChatService.DebugInfo import org.signal.libsignal.net.IpType +import org.signal.libsignal.net.Network import org.whispersystems.signalservice.api.websocket.HealthMonitor import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import java.util.concurrent.CountDownLatch @@ -25,13 +27,16 @@ class LibSignalChatConnectionTest { private val executor: ExecutorService = Executors.newSingleThreadExecutor() private val healthMonitor = mockk() private val chatService = mockk() - private val connection = LibSignalChatConnection("test", chatService, healthMonitor) + private val network = mockk() + private val connection = LibSignalChatConnection("test", network, null, false, healthMonitor) @Before fun before() { clearAllMocks() + mockkStatic(Network::createChatService) every { healthMonitor.onMessageError(any(), any()) } every { healthMonitor.onKeepAliveResponse(any(), any()) } + every { network.createChatService(any(), any()) } answers { chatService } } @Test @@ -127,25 +132,37 @@ class LibSignalChatConnectionTest { fun orderOfStatesOnDisconnectFailure() { val disconnectException = RuntimeException("disconnect failed") - val latch = CountDownLatch(1) + val connectLatch = CountDownLatch(1) + val disconnectLatch = CountDownLatch(1) every { chatService.disconnect() } answers { delay { it.completeExceptionally(disconnectException) + disconnectLatch.countDown() } } - val observer = TestObserver() + every { chatService.connect() } answers { + delay { + it.complete(DEBUG_INFO) + connectLatch.countDown() + } + } + connection.connect() + + connectLatch.await(100, TimeUnit.MILLISECONDS) + + val observer = TestObserver() connection.state.subscribe(observer) connection.disconnect() - latch.await(100, TimeUnit.MILLISECONDS) + disconnectLatch.await(100, TimeUnit.MILLISECONDS) observer.assertNotComplete() observer.assertValues( - WebSocketConnectionState.DISCONNECTED, + WebSocketConnectionState.CONNECTED, WebSocketConnectionState.DISCONNECTING, WebSocketConnectionState.DISCONNECTED ) @@ -162,6 +179,14 @@ class LibSignalChatConnectionTest { } } + every { chatService.connect() } answers { + delay { + it.complete(DEBUG_INFO) + } + } + + connection.connect() + connection.sendKeepAlive() latch.await(100, TimeUnit.MILLISECONDS) @@ -185,6 +210,14 @@ class LibSignalChatConnectionTest { } } + every { chatService.connect() } answers { + delay { + it.complete(DEBUG_INFO) + } + } + + connection.connect() + connection.sendKeepAlive() latch.await(100, TimeUnit.MILLISECONDS) @@ -200,28 +233,41 @@ class LibSignalChatConnectionTest { @Test fun keepAliveConnectionFailure() { val connectionFailure = RuntimeException("Sending keep-alive failed") - val latch = CountDownLatch(1) + + val connectLatch = CountDownLatch(1) + val keepAliveFailureLatch = CountDownLatch(1) every { chatService.sendAndDebug(any()) } answers { delay { it.completeExceptionally(connectionFailure) + keepAliveFailureLatch.countDown() } } + every { chatService.connect() } answers { + delay { + it.complete(DEBUG_INFO) + connectLatch.countDown() + } + } + + connection.connect() + connectLatch.await(100, TimeUnit.MILLISECONDS) + val observer = TestObserver() connection.state.subscribe(observer) connection.sendKeepAlive() - latch.await(100, TimeUnit.MILLISECONDS) + keepAliveFailureLatch.await(100, TimeUnit.MILLISECONDS) observer.assertNotComplete() observer.assertValues( - // This is the starting state - WebSocketConnectionState.DISCONNECTED, - // This one is the result of a keep-alive failure + // We start in the connected state + WebSocketConnectionState.CONNECTED, + // Disconnects as a result of keep-alive failure WebSocketConnectionState.DISCONNECTED ) verify(exactly = 0) {