diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencies.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencies.java index 362e910e0e..d067d7f088 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencies.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencies.java @@ -573,7 +573,7 @@ public class ApplicationDependencies { if (signalWebSocket == null) { synchronized (LOCK) { if (signalWebSocket == null) { - signalWebSocket = provider.provideSignalWebSocket(() -> getSignalServiceNetworkAccess().getConfiguration()); + signalWebSocket = provider.provideSignalWebSocket(() -> getSignalServiceNetworkAccess().getConfiguration(), ApplicationDependencies::getLibsignalNetwork); } } } @@ -726,7 +726,7 @@ public class ApplicationDependencies { @NonNull SignalCallManager provideSignalCallManager(); @NonNull PendingRetryReceiptManager providePendingRetryReceiptManager(); @NonNull PendingRetryReceiptCache providePendingRetryReceiptCache(); - @NonNull SignalWebSocket provideSignalWebSocket(@NonNull Supplier signalServiceConfigurationSupplier); + @NonNull SignalWebSocket provideSignalWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier); @NonNull SignalServiceDataStoreImpl provideProtocolStore(); @NonNull GiphyMp4Cache provideGiphyMp4Cache(); @NonNull SimpleExoPlayerPool provideExoPlayerPool(); diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index adfbf10e76..2bb2f43273 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -90,7 +90,10 @@ import org.whispersystems.signalservice.api.util.SleepTimer; import org.whispersystems.signalservice.api.util.UptimeSleepTimer; import org.whispersystems.signalservice.api.websocket.WebSocketFactory; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; +import org.whispersystems.signalservice.internal.websocket.LibSignalNetwork; import org.whispersystems.signalservice.internal.websocket.WebSocketConnection; +import org.whispersystems.signalservice.internal.websocket.LibSignalChatConnection; +import org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection; import java.util.Optional; import java.util.concurrent.TimeUnit; @@ -287,10 +290,10 @@ public class ApplicationDependencyProvider implements ApplicationDependencies.Pr } @Override - public @NonNull SignalWebSocket provideSignalWebSocket(@NonNull Supplier signalServiceConfigurationSupplier) { + public @NonNull SignalWebSocket provideSignalWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier) { SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internalValues().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer() ; SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(context, sleepTimer); - SignalWebSocket signalWebSocket = new SignalWebSocket(provideWebSocketFactory(signalServiceConfigurationSupplier, healthMonitor)); + SignalWebSocket signalWebSocket = new SignalWebSocket(provideWebSocketFactory(signalServiceConfigurationSupplier, healthMonitor, libSignalNetworkSupplier)); healthMonitor.monitor(signalWebSocket); @@ -397,26 +400,35 @@ public class ApplicationDependencyProvider implements ApplicationDependencies.Pr return provideClientZkOperations(signalServiceConfiguration).getReceiptOperations(); } - @NonNull WebSocketFactory provideWebSocketFactory(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull SignalWebSocketHealthMonitor healthMonitor) { + @NonNull WebSocketFactory provideWebSocketFactory(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull SignalWebSocketHealthMonitor healthMonitor, @NonNull Supplier libSignalNetworkSupplier) { return new WebSocketFactory() { @Override public WebSocketConnection createWebSocket() { - return new WebSocketConnection("normal", - signalServiceConfigurationSupplier.get(), - Optional.of(new DynamicCredentialsProvider()), - BuildConfig.SIGNAL_AGENT, - healthMonitor, - Stories.isFeatureEnabled()); + return new OkHttpWebSocketConnection("normal", + signalServiceConfigurationSupplier.get(), + Optional.of(new DynamicCredentialsProvider()), + BuildConfig.SIGNAL_AGENT, + healthMonitor, + Stories.isFeatureEnabled()); } @Override public WebSocketConnection createUnidentifiedWebSocket() { - return new WebSocketConnection("unidentified", - signalServiceConfigurationSupplier.get(), - Optional.empty(), - BuildConfig.SIGNAL_AGENT, - healthMonitor, - Stories.isFeatureEnabled()); + if (FeatureFlags.libSignalWebSocketEnabled()) { + var network = new LibSignalNetwork(libSignalNetworkSupplier.get()); + return new LibSignalChatConnection( + "libsignal-unauth", + network.createChatService(null), + healthMonitor, + false); + } else { + return new OkHttpWebSocketConnection("unidentified", + signalServiceConfigurationSupplier.get(), + Optional.empty(), + BuildConfig.SIGNAL_AGENT, + healthMonitor, + Stories.isFeatureEnabled()); + } } }; } diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/CheckServiceReachabilityJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/CheckServiceReachabilityJob.kt index 81a1701d94..8a98a9c46f 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/CheckServiceReachabilityJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/CheckServiceReachabilityJob.kt @@ -9,7 +9,7 @@ import org.thoughtcrime.securesms.keyvalue.SignalStore import org.thoughtcrime.securesms.stories.Stories import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.internal.util.StaticCredentialsProvider -import org.whispersystems.signalservice.internal.websocket.WebSocketConnection +import org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection import java.util.Optional import java.util.concurrent.TimeUnit @@ -64,7 +64,7 @@ class CheckServiceReachabilityJob private constructor(params: Parameters) : Base SignalStore.misc().lastCensorshipServiceReachabilityCheckTime = System.currentTimeMillis() - val uncensoredWebsocket = WebSocketConnection( + val uncensoredWebsocket = OkHttpWebSocketConnection( "uncensored-test", ApplicationDependencies.getSignalServiceNetworkAccess().uncensoredConfiguration, Optional.of( diff --git a/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.java b/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.java index 9c7ed91dc1..8686b256c2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.java +++ b/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.java @@ -11,7 +11,7 @@ import org.whispersystems.signalservice.api.util.Preconditions; import org.whispersystems.signalservice.api.util.SleepTimer; import org.whispersystems.signalservice.api.websocket.HealthMonitor; import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState; -import org.whispersystems.signalservice.internal.websocket.WebSocketConnection; +import org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection; import java.util.concurrent.Executor; import java.util.concurrent.Executors; @@ -33,7 +33,7 @@ public final class SignalWebSocketHealthMonitor implements HealthMonitor { /** * This is the amount of time in between sent keep alives. Must be greater than {@link SignalWebSocketHealthMonitor#KEEP_ALIVE_TIMEOUT} */ - private static final long KEEP_ALIVE_SEND_CADENCE = TimeUnit.SECONDS.toMillis(WebSocketConnection.KEEPALIVE_FREQUENCY_SECONDS); + private static final long KEEP_ALIVE_SEND_CADENCE = TimeUnit.SECONDS.toMillis(OkHttpWebSocketConnection.KEEPALIVE_FREQUENCY_SECONDS); /** * This is the amount of time we will wait for a response to the keep alive before we consider the websockets dead. diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java b/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java index e83ee017fa..b555f7a8e3 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java +++ b/app/src/main/java/org/thoughtcrime/securesms/util/FeatureFlags.java @@ -129,6 +129,7 @@ public final class FeatureFlags { private static final String MESSAGE_BACKUPS = "android.messageBackups"; private static final String CAMERAX_CUSTOM_CONTROLLER = "android.cameraXCustomController"; private static final String REGISTRATION_V2 = "android.registration.v2"; + private static final String LIBSIGNAL_WEB_SOCKET_ENABLED = "android.libsignalWebSocketEnabled"; /** * We will only store remote values for flags in this set. If you want a flag to be controllable @@ -208,7 +209,8 @@ public final class FeatureFlags { CDSI_LIBSIGNAL_NET, RX_MESSAGE_SEND, LINKED_DEVICE_LIFESPAN_SECONDS, - CAMERAX_CUSTOM_CONTROLLER + CAMERAX_CUSTOM_CONTROLLER, + LIBSIGNAL_WEB_SOCKET_ENABLED ); @VisibleForTesting @@ -754,6 +756,9 @@ public final class FeatureFlags { return getBoolean(REGISTRATION_V2, false); } + /** Whether unauthenticated chat web socket is backed by libsignal-net */ + public static boolean libSignalWebSocketEnabled() { return getBoolean(LIBSIGNAL_WEB_SOCKET_ENABLED, false); } + /** Only for rendering debug info. */ public static synchronized @NonNull Map getMemoryValues() { return new TreeMap<>(REMOTE_VALUES); diff --git a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.java b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.java index b600822225..88df57b100 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.java +++ b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.java @@ -181,7 +181,7 @@ public class MockApplicationDependencyProvider implements ApplicationDependencie } @Override - public @NonNull SignalWebSocket provideSignalWebSocket(@NonNull Supplier signalServiceConfigurationSupplier) { + public @NonNull SignalWebSocket provideSignalWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier) { return null; } diff --git a/libsignal-service/build.gradle.kts b/libsignal-service/build.gradle.kts index 14b411fe43..8ef56c93a8 100644 --- a/libsignal-service/build.gradle.kts +++ b/libsignal-service/build.gradle.kts @@ -95,6 +95,7 @@ dependencies { testImplementation(testLibs.assertj.core) testImplementation(testLibs.conscrypt.openjdk.uber) testImplementation(testLibs.mockito.core) + testImplementation(testLibs.mockk) testFixturesImplementation(libs.libsignal.client) testFixturesImplementation(testLibs.junit.junit) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt new file mode 100644 index 0000000000..651294d522 --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -0,0 +1,217 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.internal.websocket + +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.schedulers.Schedulers +import io.reactivex.rxjava3.subjects.BehaviorSubject +import io.reactivex.rxjava3.subjects.SingleSubject +import org.signal.core.util.logging.Log +import org.signal.libsignal.internal.CompletableFuture +import org.signal.libsignal.net.ChatService +import org.whispersystems.signalservice.api.websocket.HealthMonitor +import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState +import java.time.Instant +import java.util.Optional +import kotlin.time.Duration.Companion.seconds +import org.signal.libsignal.net.ChatService.Request as LibSignalRequest +import org.signal.libsignal.net.ChatService.Response as LibSignalResponse + +/** + * Implements the WebSocketConnection interface via libsignal-net + * + * Notable implementation choices: + * - [chatService] contains both the authenticated and unauthenticated connections, + * which one to use for [sendRequest]/[sendResponse] is based on [isAuthenticated]. + * - keep-alive requests always use the [org.signal.libsignal.net.ChatService.unauthenticatedSendAndDebug] + * API, and log the debug info on success. + * - regular sends use [org.signal.libsignal.net.ChatService.unauthenticatedSend] and don't create any overhead. + * - [org.whispersystems.signalservice.api.websocket.WebSocketConnectionState] reporting is implemented + * as close as possible to the original implementation in + * [org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection]. + */ +class LibSignalChatConnection( + name: String, + private val chatService: ChatService, + private val healthMonitor: HealthMonitor, + val isAuthenticated: Boolean +) : WebSocketConnection { + + companion object { + private val TAG = Log.tag(LibSignalChatConnection::class.java) + private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds + + private val KEEP_ALIVE_REQUEST = LibSignalRequest( + "GET", + "/v1/keepalive", + emptyMap(), + ByteArray(0), + SEND_TIMEOUT.toInt() + ) + } + + override val name = "[$name:${System.identityHashCode(this)}]" + + val state = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) + + override fun connect(): Observable { + Log.i(TAG, "$name Connecting...") + state.onNext(WebSocketConnectionState.CONNECTING) + val connect = if (isAuthenticated) { + chatService::connectAuthenticated + } else { + chatService::connectUnauthenticated + } + connect() + .whenComplete( + onSuccess = { debugInfo -> + Log.i(TAG, "$name Connected") + Log.d(TAG, "$name $debugInfo") + state.onNext(WebSocketConnectionState.CONNECTED) + }, + onFailure = { throwable -> + // TODO: [libsignal-net] Report WebSocketConnectionState.AUTHENTICATION_FAILED for 401 and 403 errors + Log.d(TAG, "$name Connect failed", throwable) + state.onNext(WebSocketConnectionState.FAILED) + } + ) + return state + } + + override fun isDead(): Boolean = false + + override fun disconnect() { + Log.i(TAG, "$name Disconnecting...") + state.onNext(WebSocketConnectionState.DISCONNECTING) + chatService.disconnect() + .whenComplete( + onSuccess = { + Log.i(TAG, "$name Disconnected") + state.onNext(WebSocketConnectionState.DISCONNECTED) + }, + onFailure = { throwable -> + Log.d(TAG, "$name Disconnect failed", throwable) + state.onNext(WebSocketConnectionState.DISCONNECTED) + } + ) + } + + override fun sendRequest(request: WebSocketRequestMessage): Single { + val single = SingleSubject.create() + val internalRequest = request.toLibSignalRequest() + val send = if (isAuthenticated) { + throw NotImplementedError("Authenticated socket is not yet supported") + } else { + chatService::unauthenticatedSend + } + send(internalRequest) + .whenComplete( + onSuccess = { response -> + when (response!!.status) { + in 400..599 -> { + healthMonitor.onMessageError(response.status, false) + } + } + // Here success means "we received the response" even if it is reporting an error. + // This is consistent with the behavior of the OkHttpWebSocketConnection. + single.onSuccess(response.toWebsocketResponse(isUnidentified = !isAuthenticated)) + }, + onFailure = { throwable -> + Log.i(TAG, "$name sendRequest failed", throwable) + single.onError(throwable) + } + ) + return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) + } + + override fun sendKeepAlive() { + Log.i(TAG, "$name Sending keep alive...") + val send = if (isAuthenticated) { + throw NotImplementedError("Authenticated socket is not yet supported") + } else { + chatService::unauthenticatedSendAndDebug + } + send(KEEP_ALIVE_REQUEST) + .whenComplete( + onSuccess = { debugResponse -> + Log.i(TAG, "$name Keep alive - success") + Log.d(TAG, "$name $debugResponse") + when (debugResponse!!.response.status) { + in 200..299 -> { + healthMonitor.onKeepAliveResponse( + Instant.now().toEpochMilli(), // ignored. can be any value + false + ) + } + + in 400..599 -> { + healthMonitor.onMessageError(debugResponse.response.status, isAuthenticated) + } + + else -> { + Log.w(TAG, "$name Unsupported keep alive response status: ${debugResponse.response.status}") + } + } + }, + onFailure = { throwable -> + Log.i(TAG, "$name Keep alive - failed") + Log.d(TAG, "$name $throwable") + state.onNext(WebSocketConnectionState.DISCONNECTED) + } + ) + } + + override fun readRequestIfAvailable(): Optional { + throw NotImplementedError() + } + + override fun readRequest(timeoutMillis: Long): WebSocketRequestMessage { + throw NotImplementedError() + } + + override fun sendResponse(response: WebSocketResponseMessage?) { + throw NotImplementedError() + } + + private fun WebSocketRequestMessage.toLibSignalRequest(timeout: Long = SEND_TIMEOUT): LibSignalRequest { + return LibSignalRequest( + this.verb?.uppercase() ?: "GET", + this.path ?: "", + this.headers.associate { + val parts = it.split(':', limit = 2) + if (parts.size != 2) { + throw IllegalArgumentException("Headers must contain at least one colon") + } + parts[0] to parts[1] + }, + this.body?.toByteArray() ?: byteArrayOf(), + timeout.toInt() + ) + } + + private fun LibSignalResponse.toWebsocketResponse(isUnidentified: Boolean): WebsocketResponse { + return WebsocketResponse( + this.status, + this.body.decodeToString(), + this.headers, + isUnidentified + ) + } + + private fun CompletableFuture.whenComplete( + onSuccess: ((T?) -> Unit), + onFailure: ((Throwable) -> Unit) + ): CompletableFuture { + return this.whenComplete { value, throwable -> + if (throwable != null) { + onFailure(throwable) + } else { + onSuccess(value) + } + } + } +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalNetwork.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalNetwork.kt new file mode 100644 index 0000000000..625ad2f44c --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalNetwork.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.internal.websocket + +import org.signal.libsignal.net.ChatService +import org.signal.libsignal.net.Network +import org.whispersystems.signalservice.api.util.CredentialsProvider + +/** + * Makes Network API more ergonomic to use with Android client types + */ +class LibSignalNetwork(private val inner: Network) { + fun createChatService( + credentialsProvider: CredentialsProvider? = null + ): ChatService { + val username = credentialsProvider?.username ?: "" + val password = credentialsProvider?.password ?: "" + return inner.createChatService(username, password) + } +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java similarity index 92% rename from libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java rename to libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java index e39142812e..305e69624d 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java @@ -2,7 +2,6 @@ package org.whispersystems.signalservice.internal.websocket; import org.signal.libsignal.protocol.logging.Log; import org.signal.libsignal.protocol.util.Pair; -import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.TrustStore; import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.Tls12SocketFactory; @@ -25,7 +24,6 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -51,10 +49,10 @@ import okhttp3.WebSocket; import okhttp3.WebSocketListener; import okio.ByteString; -public class WebSocketConnection extends WebSocketListener { +public class OkHttpWebSocketConnection extends WebSocketListener implements WebSocketConnection { - private static final String TAG = WebSocketConnection.class.getSimpleName(); - public static final int KEEPALIVE_FREQUENCY_SECONDS = 30; + private static final String TAG = OkHttpWebSocketConnection.class.getSimpleName(); + public static final int KEEPALIVE_FREQUENCY_SECONDS = 30; private final LinkedList incomingRequests = new LinkedList<>(); private final Map outgoingRequests = new HashMap<>(); @@ -76,22 +74,22 @@ public class WebSocketConnection extends WebSocketListener { private WebSocket client; - public WebSocketConnection(String name, - SignalServiceConfiguration serviceConfiguration, - Optional credentialsProvider, - String signalAgent, - HealthMonitor healthMonitor, - boolean allowStories) { + public OkHttpWebSocketConnection(String name, + SignalServiceConfiguration serviceConfiguration, + Optional credentialsProvider, + String signalAgent, + HealthMonitor healthMonitor, + boolean allowStories) { this(name, serviceConfiguration, credentialsProvider, signalAgent, healthMonitor, "", allowStories); } - public WebSocketConnection(String name, - SignalServiceConfiguration serviceConfiguration, - Optional credentialsProvider, - String signalAgent, - HealthMonitor healthMonitor, - String extraPathUri, - boolean allowStories) + public OkHttpWebSocketConnection(String name, + SignalServiceConfiguration serviceConfiguration, + Optional credentialsProvider, + String signalAgent, + HealthMonitor healthMonitor, + String extraPathUri, + boolean allowStories) { this.name = "[" + name + ":" + System.identityHashCode(this) + "]"; this.trustStore = serviceConfiguration.getSignalServiceUrls()[0].getTrustStore(); @@ -108,6 +106,7 @@ public class WebSocketConnection extends WebSocketListener { this.random = new SecureRandom(); } + @Override public String getName() { return name; } @@ -123,6 +122,7 @@ public class WebSocketConnection extends WebSocketListener { } } + @Override public synchronized Observable connect() { log("connect()"); @@ -130,7 +130,7 @@ public class WebSocketConnection extends WebSocketListener { Pair connectionInfo = getConnectionInfo(); SignalServiceUrl serviceUrl = connectionInfo.first(); String wsUri = connectionInfo.second(); - String filledUri; + String filledUri; if (credentialsProvider.isPresent()) { filledUri = String.format(wsUri, credentialsProvider.get().getUsername(), credentialsProvider.get().getPassword()); @@ -177,10 +177,12 @@ public class WebSocketConnection extends WebSocketListener { return webSocketState; } + @Override public synchronized boolean isDead() { return client == null; } + @Override public synchronized void disconnect() { log("disconnect()"); @@ -193,6 +195,7 @@ public class WebSocketConnection extends WebSocketListener { notifyAll(); } + @Override public synchronized Optional readRequestIfAvailable() { if (incomingRequests.size() > 0) { return Optional.of(incomingRequests.removeFirst()); @@ -201,6 +204,7 @@ public class WebSocketConnection extends WebSocketListener { } } + @Override public synchronized WebSocketRequestMessage readRequest(long timeoutMillis) throws TimeoutException, IOException { @@ -223,6 +227,7 @@ public class WebSocketConnection extends WebSocketListener { } } + @Override public synchronized Single sendRequest(WebSocketRequestMessage request) throws IOException { if (client == null) { throw new IOException("No connection!"); @@ -246,6 +251,7 @@ public class WebSocketConnection extends WebSocketListener { .timeout(10, TimeUnit.SECONDS, Schedulers.io()); } + @Override public synchronized void sendResponse(WebSocketResponseMessage response) throws IOException { if (client == null) { throw new IOException("Connection closed!"); @@ -261,9 +267,10 @@ public class WebSocketConnection extends WebSocketListener { } } + @Override public synchronized void sendKeepAlive() throws IOException { if (client != null) { - log( "Sending keep alive..."); + log("Sending keep alive..."); long id = System.currentTimeMillis(); byte[] message = new WebSocketMessage.Builder() .type(WebSocketMessage.Type.REQUEST) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt new file mode 100644 index 0000000000..d6e0e1197f --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt @@ -0,0 +1,39 @@ +package org.whispersystems.signalservice.internal.websocket + +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single +import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState +import java.io.IOException +import java.util.Optional +import java.util.concurrent.TimeoutException + +/** + * Common interface for the web socket connection API + * + * At the time of this writing there are two implementations available: + * - OkHttpWebSocketConnection - the original Android client implementation in Java using OkHttp library + * - LibSignalChatConnection - the wrapper around libsignal's [org.signal.libsignal.net.ChatService] + */ +interface WebSocketConnection { + val name: String + + fun connect(): Observable + + fun isDead(): Boolean + + fun disconnect() + + @Throws(IOException::class) + fun sendRequest(request: WebSocketRequestMessage): Single + + @Throws(IOException::class) + fun sendKeepAlive() + + fun readRequestIfAvailable(): Optional + + @Throws(TimeoutException::class, IOException::class) + fun readRequest(timeoutMillis: Long): WebSocketRequestMessage + + @Throws(IOException::class) + fun sendResponse(response: WebSocketResponseMessage?) +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java index 79ae17c9b2..940aa0aba8 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebsocketResponse.java @@ -15,9 +15,13 @@ public class WebsocketResponse { private final boolean unidentified; WebsocketResponse(int status, String body, List headers, boolean unidentified) { + this(status, body, parseHeaders(headers), unidentified); + } + + WebsocketResponse(int status, String body, Map headerMap, boolean unidentified) { this.status = status; this.body = body; - this.headers = parseHeaders(headers); + this.headers = headerMap; this.unidentified = unidentified; } @@ -41,7 +45,7 @@ public class WebsocketResponse { Map headers = new HashMap<>(rawHeaders.size()); for (String raw : rawHeaders) { - if (raw != null && raw.length() > 0) { + if (raw != null && !raw.isEmpty()) { int colonIndex = raw.indexOf(":"); if (colonIndex > 0 && colonIndex < raw.length() - 1) { diff --git a/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt new file mode 100644 index 0000000000..63350eacc7 --- /dev/null +++ b/libsignal-service/src/test/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnectionTest.kt @@ -0,0 +1,251 @@ +package org.whispersystems.signalservice.internal.websocket + +import io.mockk.clearAllMocks +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.rxjava3.observers.TestObserver +import org.junit.Before +import org.junit.Test +import org.signal.libsignal.internal.CompletableFuture +import org.signal.libsignal.net.ChatService +import org.signal.libsignal.net.ChatService.DebugInfo +import org.signal.libsignal.net.IpType +import org.whispersystems.signalservice.api.websocket.HealthMonitor +import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState +import java.util.concurrent.CountDownLatch +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit +import org.signal.libsignal.net.ChatService.Response as LibSignalResponse +import org.signal.libsignal.net.ChatService.ResponseAndDebugInfo as LibSignalDebugResponse + +class LibSignalChatConnectionTest { + + private val executor: ExecutorService = Executors.newSingleThreadExecutor() + private val healthMonitor = mockk() + private val chatService = mockk() + private val connection = LibSignalChatConnection("test", chatService, healthMonitor, isAuthenticated = false) + + @Before + fun before() { + clearAllMocks() + every { healthMonitor.onMessageError(any(), any()) } + every { healthMonitor.onKeepAliveResponse(any(), any()) } + } + + @Test + fun orderOfStatesOnSuccessfulConnect() { + val latch = CountDownLatch(1) + + every { chatService.connectUnauthenticated() } answers { + delay { + it.complete(DEBUG_INFO) + latch.countDown() + } + } + + val observer = TestObserver() + connection.state.subscribe(observer) + + connection.connect() + + latch.await(100, TimeUnit.MILLISECONDS) + + observer.assertNotComplete() + observer.assertValues( + WebSocketConnectionState.DISCONNECTED, + WebSocketConnectionState.CONNECTING, + WebSocketConnectionState.CONNECTED + ) + } + + @Test + fun orderOfStatesOnConnectionFailure() { + val connectionException = RuntimeException("connect failed") + val latch = CountDownLatch(1) + + every { chatService.connectUnauthenticated() } answers { + delay { + it.completeExceptionally(connectionException) + } + } + + val observer = TestObserver() + connection.state.subscribe(observer) + + connection.connect() + + latch.await(100, TimeUnit.MILLISECONDS) + + observer.assertNotComplete() + observer.assertValues( + WebSocketConnectionState.DISCONNECTED, + WebSocketConnectionState.CONNECTING, + WebSocketConnectionState.FAILED + ) + } + + @Test + fun orderOfStatesOnConnectAndDisconnect() { + val connectLatch = CountDownLatch(1) + val disconnectLatch = CountDownLatch(1) + + every { chatService.connectUnauthenticated() } answers { + delay { + it.complete(DEBUG_INFO) + connectLatch.countDown() + } + } + every { chatService.disconnect() } answers { + delay { + it.complete(null) + disconnectLatch.countDown() + } + } + + val observer = TestObserver() + + connection.state.subscribe(observer) + + connection.connect() + connectLatch.await(100, TimeUnit.MILLISECONDS) + connection.disconnect() + disconnectLatch.await(100, TimeUnit.MILLISECONDS) + + observer.assertNotComplete() + observer.assertValues( + WebSocketConnectionState.DISCONNECTED, + WebSocketConnectionState.CONNECTING, + WebSocketConnectionState.CONNECTED, + WebSocketConnectionState.DISCONNECTING, + WebSocketConnectionState.DISCONNECTED + ) + } + + @Test + fun orderOfStatesOnDisconnectFailure() { + val disconnectException = RuntimeException("disconnect failed") + + val latch = CountDownLatch(1) + + every { chatService.disconnect() } answers { + delay { + it.completeExceptionally(disconnectException) + } + } + + val observer = TestObserver() + + connection.state.subscribe(observer) + + connection.disconnect() + + latch.await(100, TimeUnit.MILLISECONDS) + + observer.assertNotComplete() + observer.assertValues( + WebSocketConnectionState.DISCONNECTED, + WebSocketConnectionState.DISCONNECTING, + WebSocketConnectionState.DISCONNECTED + ) + } + + @Test + fun keepAliveSuccess() { + val latch = CountDownLatch(1) + + every { chatService.unauthenticatedSendAndDebug(any()) } answers { + delay { + it.complete(make_debug_response(RESPONSE_SUCCESS)) + latch.countDown() + } + } + + connection.sendKeepAlive() + + latch.await(100, TimeUnit.MILLISECONDS) + + verify(exactly = 1) { + healthMonitor.onKeepAliveResponse(any(), false) + } + verify(exactly = 0) { + healthMonitor.onMessageError(any(), any()) + } + } + + @Test + fun keepAliveFailure() { + for (response in listOf(RESPONSE_ERROR, RESPONSE_SERVER_ERROR)) { + val latch = CountDownLatch(1) + + every { chatService.unauthenticatedSendAndDebug(any()) } answers { + delay { + it.complete(make_debug_response(response)) + } + } + + connection.sendKeepAlive() + latch.await(100, TimeUnit.MILLISECONDS) + + verify(exactly = 1) { + healthMonitor.onMessageError(response.status, false) + } + verify(exactly = 0) { + healthMonitor.onKeepAliveResponse(any(), any()) + } + } + } + + @Test + fun keepAliveConnectionFailure() { + val connectionFailure = RuntimeException("Sending keep-alive failed") + val latch = CountDownLatch(1) + + every { + chatService.unauthenticatedSendAndDebug(any()) + } answers { + delay { + it.completeExceptionally(connectionFailure) + } + } + + val observer = TestObserver() + connection.state.subscribe(observer) + + connection.sendKeepAlive() + + latch.await(100, TimeUnit.MILLISECONDS) + + observer.assertNotComplete() + observer.assertValues( + // This is the starting state + WebSocketConnectionState.DISCONNECTED, + // This one is the result of a keep-alive failure + WebSocketConnectionState.DISCONNECTED + ) + verify(exactly = 0) { + healthMonitor.onKeepAliveResponse(any(), any()) + healthMonitor.onMessageError(any(), any()) + } + } + + private fun delay(action: ((CompletableFuture) -> Unit)): CompletableFuture { + val future = CompletableFuture() + executor.submit { + action(future) + } + return future + } + + companion object { + private val DEBUG_INFO: DebugInfo = DebugInfo(0, IpType.UNKNOWN, 100, "") + private val RESPONSE_SUCCESS = LibSignalResponse(200, "", emptyMap(), byteArrayOf()) + private val RESPONSE_ERROR = LibSignalResponse(400, "", emptyMap(), byteArrayOf()) + private val RESPONSE_SERVER_ERROR = LibSignalResponse(500, "", emptyMap(), byteArrayOf()) + + private fun make_debug_response(response: LibSignalResponse): LibSignalDebugResponse { + return LibSignalDebugResponse(response, DEBUG_INFO) + } + } +}