diff --git a/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt b/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt index c1c043592c..f540d0e7d2 100644 --- a/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt +++ b/app/src/androidTest/java/org/thoughtcrime/securesms/dependencies/InstrumentationApplicationDependencyProvider.kt @@ -26,8 +26,8 @@ import org.thoughtcrime.securesms.testing.runSync import org.thoughtcrime.securesms.testing.success import org.whispersystems.signalservice.api.SignalServiceDataStore import org.whispersystems.signalservice.api.SignalServiceMessageSender -import org.whispersystems.signalservice.api.SignalWebSocket import org.whispersystems.signalservice.api.push.TrustStore +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.internal.configuration.SignalCdnUrl import org.whispersystems.signalservice.internal.configuration.SignalCdsiUrl import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration @@ -121,12 +121,13 @@ class InstrumentationApplicationDependencyProvider(val application: Application, } override fun provideSignalServiceMessageSender( - signalWebSocket: SignalWebSocket, + authWebSocket: SignalWebSocket.AuthenticatedWebSocket, + unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket, protocolStore: SignalServiceDataStore, pushServiceSocket: PushServiceSocket ): SignalServiceMessageSender { if (signalServiceMessageSender == null) { - signalServiceMessageSender = spyk(objToCopy = default.provideSignalServiceMessageSender(signalWebSocket, protocolStore, pushServiceSocket)) + signalServiceMessageSender = spyk(objToCopy = default.provideSignalServiceMessageSender(authWebSocket, unauthWebSocket, protocolStore, pushServiceSocket)) } return signalServiceMessageSender!! } diff --git a/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java b/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java index 3f96244b83..4d2999e178 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java +++ b/app/src/main/java/org/thoughtcrime/securesms/ApplicationContext.java @@ -259,6 +259,7 @@ public class ApplicationContext extends Application implements AppForegroundObse checkFreeDiskSpace(); MemoryTracker.start(); BackupSubscriptionCheckJob.enqueueIfAble(); + AppDependencies.getUnauthWebSocket().setShouldSendKeepAlives(true); long lastForegroundTime = SignalStore.misc().getLastForegroundTime(); long currentTime = System.currentTimeMillis(); @@ -282,6 +283,7 @@ public class ApplicationContext extends Application implements AppForegroundObse AppDependencies.getFrameRateTracker().stop(); AppDependencies.getShakeToReport().disable(); AppDependencies.getDeadlockDetector().stop(); + AppDependencies.getUnauthWebSocket().setShouldSendKeepAlives(false); MemoryTracker.stop(); AnrDetector.stop(); } @@ -378,7 +380,7 @@ public class ApplicationContext extends Application implements AppForegroundObse } public void initializeMessageRetrieval() { - AppDependencies.getIncomingMessageObserver(); + AppDependencies.startNetwork(); } @VisibleForTesting diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt index bc12e3d9b6..f27fb93dcc 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/changenumber/ChangeNumberRepository.kt @@ -190,7 +190,7 @@ class ChangeNumberRepository( StorageSyncHelper.scheduleSyncForDataChange() AppDependencies.resetNetwork() - AppDependencies.incomingMessageObserver + AppDependencies.startNetwork() AppDependencies.jobManager.add(RefreshAttributesJob()) diff --git a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/privacy/advanced/AdvancedPrivacySettingsViewModel.kt b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/privacy/advanced/AdvancedPrivacySettingsViewModel.kt index a125281459..3b9e4d84b9 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/privacy/advanced/AdvancedPrivacySettingsViewModel.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/components/settings/app/privacy/advanced/AdvancedPrivacySettingsViewModel.kt @@ -93,7 +93,7 @@ class AdvancedPrivacySettingsViewModel( val isCountryCodeCensoredByDefault: Boolean = AppDependencies.signalServiceNetworkAccess.isCountryCodeCensoredByDefault(countryCode) val enabledState: SettingsValues.CensorshipCircumventionEnabled = SignalStore.settings.censorshipCircumventionEnabled val hasInternet: Boolean = NetworkConstraint.isMet(AppDependencies.application) - val websocketConnected: Boolean = AppDependencies.signalWebSocket.webSocketState.firstOrError().blockingGet() == WebSocketConnectionState.CONNECTED + val websocketConnected: Boolean = AppDependencies.authWebSocket.state.firstOrError().blockingGet() == WebSocketConnectionState.CONNECTED return when { SignalStore.internal.allowChangingCensorshipSetting -> { diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt b/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt index 549dd5ff75..7792a5c9b2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt @@ -42,7 +42,6 @@ import org.whispersystems.signalservice.api.SignalServiceAccountManager import org.whispersystems.signalservice.api.SignalServiceDataStore import org.whispersystems.signalservice.api.SignalServiceMessageReceiver import org.whispersystems.signalservice.api.SignalServiceMessageSender -import org.whispersystems.signalservice.api.SignalWebSocket import org.whispersystems.signalservice.api.archive.ArchiveApi import org.whispersystems.signalservice.api.attachment.AttachmentApi import org.whispersystems.signalservice.api.groupsv2.GroupsV2Operations @@ -53,6 +52,7 @@ import org.whispersystems.signalservice.api.services.CallLinksService import org.whispersystems.signalservice.api.services.DonationsService import org.whispersystems.signalservice.api.services.ProfileService import org.whispersystems.signalservice.api.storage.StorageServiceApi +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration import org.whispersystems.signalservice.internal.push.PushServiceSocket @@ -214,7 +214,7 @@ object AppDependencies { /** * An observable that emits the current state of the WebSocket connection across the various lifecycles - * of the [signalWebSocket]. + * of the [authWebSocket]. */ @JvmStatic val webSocketObserver: LatestValueObservable = LatestValueObservable(_webSocketObserver) @@ -253,8 +253,12 @@ object AppDependencies { get() = networkModule.libsignalNetwork @JvmStatic - val signalWebSocket: SignalWebSocket - get() = networkModule.signalWebSocket + val authWebSocket: SignalWebSocket.AuthenticatedWebSocket + get() = networkModule.authWebSocket + + @JvmStatic + val unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket + get() = networkModule.unauthWebSocket @JvmStatic val groupsV2Authorization: GroupsV2Authorization @@ -326,11 +330,16 @@ object AppDependencies { _networkModule.reset() } + @JvmStatic + fun startNetwork() { + networkModule.openConnections() + } + interface Provider { fun providePushServiceSocket(signalServiceConfiguration: SignalServiceConfiguration, groupsV2Operations: GroupsV2Operations): PushServiceSocket fun provideGroupsV2Operations(signalServiceConfiguration: SignalServiceConfiguration): GroupsV2Operations fun provideSignalServiceAccountManager(pushServiceSocket: PushServiceSocket, groupsV2Operations: GroupsV2Operations): SignalServiceAccountManager - fun provideSignalServiceMessageSender(signalWebSocket: SignalWebSocket, protocolStore: SignalServiceDataStore, pushServiceSocket: PushServiceSocket): SignalServiceMessageSender + fun provideSignalServiceMessageSender(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket, protocolStore: SignalServiceDataStore, pushServiceSocket: PushServiceSocket): SignalServiceMessageSender fun provideSignalServiceMessageReceiver(pushServiceSocket: PushServiceSocket): SignalServiceMessageReceiver fun provideSignalServiceNetworkAccess(): SignalServiceNetworkAccess fun provideRecipientCache(): LiveRecipientCache @@ -339,7 +348,7 @@ object AppDependencies { fun provideMegaphoneRepository(): MegaphoneRepository fun provideEarlyMessageCache(): EarlyMessageCache fun provideMessageNotifier(): MessageNotifier - fun provideIncomingMessageObserver(signalWebSocket: SignalWebSocket): IncomingMessageObserver + fun provideIncomingMessageObserver(webSocket: SignalWebSocket.AuthenticatedWebSocket): IncomingMessageObserver fun provideTrimThreadsByDateManager(): TrimThreadsByDateManager fun provideViewOnceMessageManager(): ViewOnceMessageManager fun provideExpiringStoriesManager(): ExpiringStoriesManager @@ -353,14 +362,13 @@ object AppDependencies { fun provideSignalCallManager(): SignalCallManager fun providePendingRetryReceiptManager(): PendingRetryReceiptManager fun providePendingRetryReceiptCache(): PendingRetryReceiptCache - fun provideSignalWebSocket(signalServiceConfigurationSupplier: Supplier, libSignalNetworkSupplier: Supplier): SignalWebSocket fun provideProtocolStore(): SignalServiceDataStoreImpl fun provideGiphyMp4Cache(): GiphyMp4Cache fun provideExoPlayerPool(): SimpleExoPlayerPool fun provideAndroidCallAudioManager(): AudioManagerCompat fun provideDonationsService(pushServiceSocket: PushServiceSocket): DonationsService fun provideCallLinksService(pushServiceSocket: PushServiceSocket): CallLinksService - fun provideProfileService(profileOperations: ClientZkProfileOperations, signalServiceMessageReceiver: SignalServiceMessageReceiver, signalWebSocket: SignalWebSocket): ProfileService + fun provideProfileService(profileOperations: ClientZkProfileOperations, signalServiceMessageReceiver: SignalServiceMessageReceiver, authWebSocket: SignalWebSocket.AuthenticatedWebSocket, unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket): ProfileService fun provideDeadlockDetector(): DeadlockDetector fun provideClientZkReceiptOperations(signalServiceConfiguration: SignalServiceConfiguration): ClientZkReceiptOperations fun provideScheduledMessageManager(): ScheduledMessageManager @@ -368,9 +376,11 @@ object AppDependencies { fun provideBillingApi(): BillingApi fun provideArchiveApi(pushServiceSocket: PushServiceSocket): ArchiveApi fun provideKeysApi(pushServiceSocket: PushServiceSocket): KeysApi - fun provideAttachmentApi(signalWebSocket: SignalWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi + fun provideAttachmentApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi fun provideLinkDeviceApi(pushServiceSocket: PushServiceSocket): LinkDeviceApi fun provideRegistrationApi(pushServiceSocket: PushServiceSocket): RegistrationApi fun provideStorageServiceApi(pushServiceSocket: PushServiceSocket): StorageServiceApi + fun provideAuthWebSocket(signalServiceConfigurationSupplier: Supplier, libSignalNetworkSupplier: Supplier): SignalWebSocket.AuthenticatedWebSocket + fun provideUnauthWebSocket(signalServiceConfigurationSupplier: Supplier, libSignalNetworkSupplier: Supplier): SignalWebSocket.UnauthenticatedWebSocket } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index 37b75cc342..4fb90ebded 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -81,7 +81,6 @@ import org.whispersystems.signalservice.api.SignalServiceAccountManager; import org.whispersystems.signalservice.api.SignalServiceDataStore; import org.whispersystems.signalservice.api.SignalServiceMessageReceiver; import org.whispersystems.signalservice.api.SignalServiceMessageSender; -import org.whispersystems.signalservice.api.SignalWebSocket; import org.whispersystems.signalservice.api.archive.ArchiveApi; import org.whispersystems.signalservice.api.attachment.AttachmentApi; import org.whispersystems.signalservice.api.groupsv2.ClientZkOperations; @@ -98,6 +97,8 @@ import org.whispersystems.signalservice.api.storage.StorageServiceApi; import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.SleepTimer; import org.whispersystems.signalservice.api.util.UptimeSleepTimer; +import org.whispersystems.signalservice.api.websocket.HealthMonitor; +import org.whispersystems.signalservice.api.websocket.SignalWebSocket; import org.whispersystems.signalservice.api.websocket.WebSocketFactory; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.push.PushServiceSocket; @@ -147,11 +148,12 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { } @Override - public @NonNull SignalServiceMessageSender provideSignalServiceMessageSender(@NonNull SignalWebSocket signalWebSocket, @NonNull SignalServiceDataStore protocolStore, @NonNull PushServiceSocket pushServiceSocket) { + public @NonNull SignalServiceMessageSender provideSignalServiceMessageSender(@NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket, @NonNull SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket, @NonNull SignalServiceDataStore protocolStore, @NonNull PushServiceSocket pushServiceSocket) { return new SignalServiceMessageSender(pushServiceSocket, protocolStore, ReentrantSessionLock.INSTANCE, - signalWebSocket, + authWebSocket, + unauthWebSocket, Optional.of(new SecurityEventListener(context)), SignalExecutors.newCachedBoundedExecutor("signal-messages", ThreadUtil.PRIORITY_IMPORTANT_BACKGROUND_THREAD, 1, 16, 30), ByteUnit.KILOBYTES.toBytes(256)); @@ -207,8 +209,8 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { } @Override - public @NonNull IncomingMessageObserver provideIncomingMessageObserver(@NonNull SignalWebSocket signalWebSocket) { - return new IncomingMessageObserver(context, signalWebSocket); + public @NonNull IncomingMessageObserver provideIncomingMessageObserver(@NonNull SignalWebSocket.AuthenticatedWebSocket webSocket) { + return new IncomingMessageObserver(context, webSocket); } @Override @@ -297,15 +299,29 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { } @Override - public @NonNull SignalWebSocket provideSignalWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier) { - SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer(); - SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(context, sleepTimer); - WebSocketShadowingBridge bridge = new DefaultWebSocketShadowingBridge(context); - SignalWebSocket signalWebSocket = new SignalWebSocket(provideWebSocketFactory(signalServiceConfigurationSupplier, healthMonitor, libSignalNetworkSupplier, bridge)); + public @NonNull SignalWebSocket.AuthenticatedWebSocket provideAuthWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier) { + SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer(); + SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer); + WebSocketShadowingBridge bridge = new DefaultWebSocketShadowingBridge(context); + WebSocketFactory webSocketFactory = provideWebSocketFactory(signalServiceConfigurationSupplier, healthMonitor, libSignalNetworkSupplier, bridge); + SignalWebSocket.AuthenticatedWebSocket webSocket = new SignalWebSocket.AuthenticatedWebSocket(webSocketFactory::createWebSocket); - healthMonitor.monitor(signalWebSocket); + healthMonitor.monitor(webSocket); - return signalWebSocket; + return webSocket; + } + + @Override + public @NonNull SignalWebSocket.UnauthenticatedWebSocket provideUnauthWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier) { + SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer(); + SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer); + WebSocketShadowingBridge bridge = new DefaultWebSocketShadowingBridge(context); + WebSocketFactory webSocketFactory = provideWebSocketFactory(signalServiceConfigurationSupplier, healthMonitor, libSignalNetworkSupplier, bridge); + SignalWebSocket.UnauthenticatedWebSocket webSocket = new SignalWebSocket.UnauthenticatedWebSocket(webSocketFactory::createUnidentifiedWebSocket); + + healthMonitor.monitor(webSocket); + + return webSocket; } @Override @@ -383,9 +399,10 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { @Override public @NonNull ProfileService provideProfileService(@NonNull ClientZkProfileOperations clientZkProfileOperations, @NonNull SignalServiceMessageReceiver receiver, - @NonNull SignalWebSocket signalWebSocket) + @NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket, + @NonNull SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket) { - return new ProfileService(clientZkProfileOperations, receiver, signalWebSocket); + return new ProfileService(clientZkProfileOperations, receiver, authWebSocket, unauthWebSocket); } @Override @@ -401,7 +418,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { } @NonNull WebSocketFactory provideWebSocketFactory(@NonNull Supplier signalServiceConfigurationSupplier, - @NonNull SignalWebSocketHealthMonitor healthMonitor, + @NonNull HealthMonitor healthMonitor, @NonNull Supplier libSignalNetworkSupplier, @NonNull WebSocketShadowingBridge bridge) { @@ -479,8 +496,8 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { } @Override - public @NonNull AttachmentApi provideAttachmentApi(@NonNull SignalWebSocket signalWebSocket, @NonNull PushServiceSocket pushServiceSocket) { - return new AttachmentApi(signalWebSocket, pushServiceSocket); + public @NonNull AttachmentApi provideAttachmentApi(@NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket, @NonNull PushServiceSocket pushServiceSocket) { + return new AttachmentApi(authWebSocket, pushServiceSocket); } @Override diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt b/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt index a5348627ae..572de3043c 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt @@ -26,7 +26,6 @@ import org.thoughtcrime.securesms.push.SignalServiceTrustStore import org.whispersystems.signalservice.api.SignalServiceAccountManager import org.whispersystems.signalservice.api.SignalServiceMessageReceiver import org.whispersystems.signalservice.api.SignalServiceMessageSender -import org.whispersystems.signalservice.api.SignalWebSocket import org.whispersystems.signalservice.api.archive.ArchiveApi import org.whispersystems.signalservice.api.attachment.AttachmentApi import org.whispersystems.signalservice.api.groupsv2.GroupsV2Operations @@ -39,6 +38,7 @@ import org.whispersystems.signalservice.api.services.DonationsService import org.whispersystems.signalservice.api.services.ProfileService import org.whispersystems.signalservice.api.storage.StorageServiceApi import org.whispersystems.signalservice.api.util.Tls12SocketFactory +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.internal.push.PushServiceSocket import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager @@ -70,12 +70,12 @@ class NetworkDependenciesModule( val protocolStore: SignalServiceDataStoreImpl by _protocolStore private val _signalServiceMessageSender = resettableLazy { - provider.provideSignalServiceMessageSender(signalWebSocket, protocolStore, pushServiceSocket) + provider.provideSignalServiceMessageSender(authWebSocket, unauthWebSocket, protocolStore, pushServiceSocket) } val signalServiceMessageSender: SignalServiceMessageSender by _signalServiceMessageSender val incomingMessageObserver: IncomingMessageObserver by lazy { - provider.provideIncomingMessageObserver(signalWebSocket) + provider.provideIncomingMessageObserver(authWebSocket) } val pushServiceSocket: PushServiceSocket by lazy { @@ -90,12 +90,16 @@ class NetworkDependenciesModule( provider.provideLibsignalNetwork(signalServiceNetworkAccess.getConfiguration()) } - val signalWebSocket: SignalWebSocket by lazy { - provider.provideSignalWebSocket({ signalServiceNetworkAccess.getConfiguration() }, { libsignalNetwork }).also { - disposables += it.webSocketState.subscribe { webSocketStateSubject.onNext(it) } + val authWebSocket: SignalWebSocket.AuthenticatedWebSocket by lazy { + provider.provideAuthWebSocket({ signalServiceNetworkAccess.getConfiguration() }, { libsignalNetwork }).also { + disposables += it.state.subscribe { s -> webSocketStateSubject.onNext(s) } } } + val unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket by lazy { + provider.provideUnauthWebSocket({ signalServiceNetworkAccess.getConfiguration() }, { libsignalNetwork }) + } + val groupsV2Authorization: GroupsV2Authorization by lazy { val authCache: GroupsV2Authorization.ValueCache = GroupsV2AuthorizationMemoryValueCache(SignalStore.groupsV2AciAuthorizationCache) GroupsV2Authorization(signalServiceAccountManager.groupsV2Api, authCache) @@ -122,7 +126,7 @@ class NetworkDependenciesModule( } val profileService: ProfileService by lazy { - provider.provideProfileService(groupsV2Operations.profileOperations, signalServiceMessageReceiver, signalWebSocket) + provider.provideProfileService(groupsV2Operations.profileOperations, signalServiceMessageReceiver, authWebSocket, unauthWebSocket) } val donationsService: DonationsService by lazy { @@ -138,7 +142,7 @@ class NetworkDependenciesModule( } val attachmentApi: AttachmentApi by lazy { - provider.provideAttachmentApi(signalWebSocket, pushServiceSocket) + provider.provideAttachmentApi(authWebSocket, pushServiceSocket) } val linkDeviceApi: LinkDeviceApi by lazy { @@ -185,9 +189,15 @@ class NetworkDependenciesModule( if (_signalServiceMessageSender.isInitialized()) { signalServiceMessageSender.cancelInFlightRequests() } + unauthWebSocket.disconnect() disposables.clear() } + fun openConnections() { + incomingMessageObserver + unauthWebSocket.connect() + } + fun resetProtocolStores() { _protocolStore.reset() _signalServiceMessageSender.reset() diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentUploadJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentUploadJob.kt index d4fcf01a35..5be31299a9 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentUploadJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/AttachmentUploadJob.kt @@ -196,6 +196,7 @@ class AttachmentUploadJob private constructor( if (lastReset > now || lastReset + NETWORK_RESET_THRESHOLD > now) { Log.w(TAG, "Our existing connections is getting repeatedly denied by the server, reset network to establish new connections") AppDependencies.resetNetwork() + AppDependencies.startNetwork() SignalStore.misc.lastNetworkResetDueToStreamResets = now } else { Log.i(TAG, "Stream reset during upload, not resetting network yet, last reset: $lastReset") diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt index feda8e2072..8df7e69f06 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt @@ -33,10 +33,10 @@ import org.thoughtcrime.securesms.util.AppForegroundObserver import org.thoughtcrime.securesms.util.RemoteConfig import org.thoughtcrime.securesms.util.SignalLocalMetrics import org.thoughtcrime.securesms.util.asChain -import org.whispersystems.signalservice.api.SignalWebSocket import org.whispersystems.signalservice.api.push.ServiceId import org.whispersystems.signalservice.api.util.SleepTimer import org.whispersystems.signalservice.api.util.UptimeSleepTimer +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException import org.whispersystems.signalservice.internal.push.Envelope @@ -54,10 +54,9 @@ import kotlin.time.Duration.Companion.seconds /** * The application-level manager of our websocket connection. * - * * This class is responsible for opening/closing the websocket based on the app's state and observing new inbound messages received on the websocket. */ -class IncomingMessageObserver(private val context: Application, private val signalWebSocket: SignalWebSocket) { +class IncomingMessageObserver(private val context: Application, private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket) { companion object { private val TAG = Log.tag(IncomingMessageObserver::class.java) @@ -244,7 +243,7 @@ class IncomingMessageObserver(private val context: Application, private val sign } private fun disconnect() { - signalWebSocket.disconnect() + authWebSocket.disconnect() } @JvmOverloads @@ -384,7 +383,7 @@ class IncomingMessageObserver(private val context: Application, private val sign waitForConnectionNecessary() Log.i(TAG, "Making websocket connection....") - val webSocketDisposable = signalWebSocket.webSocketState.subscribe { state: WebSocketConnectionState -> + val webSocketDisposable = authWebSocket.state.subscribe { state: WebSocketConnectionState -> Log.d(TAG, "WebSocket State: $state") // Any change to a non-connected state means that we are not drained @@ -397,13 +396,13 @@ class IncomingMessageObserver(private val context: Application, private val sign } } - signalWebSocket.connect() + authWebSocket.connect() try { while (!terminated && isConnectionNecessary()) { try { Log.d(TAG, "Reading message...") - val hasMore = signalWebSocket.readMessageBatch(websocketReadTimeout, 30) { batch -> + val hasMore = authWebSocket.readMessageBatch(websocketReadTimeout, 30) { batch -> Log.i(TAG, "Retrieved ${batch.size} envelopes!") val bufferedStore = BufferedProtocolStore.create() @@ -425,7 +424,7 @@ class IncomingMessageObserver(private val context: Application, private val sign AppDependencies.jobManager.addAllChains(jobs) } - signalWebSocket.sendAck(response) + authWebSocket.sendAck(response) } } } @@ -448,7 +447,7 @@ class IncomingMessageObserver(private val context: Application, private val sign } } catch (e: WebSocketUnavailableException) { Log.i(TAG, "Pipe unexpectedly unavailable, connecting") - signalWebSocket.connect() + authWebSocket.connect() } catch (e: TimeoutException) { Log.w(TAG, "Application level read timeout...") attempts = 0 diff --git a/app/src/main/java/org/thoughtcrime/securesms/net/DeviceTransferBlockingInterceptor.java b/app/src/main/java/org/thoughtcrime/securesms/net/DeviceTransferBlockingInterceptor.java index 692ca3c901..87eb617c45 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/net/DeviceTransferBlockingInterceptor.java +++ b/app/src/main/java/org/thoughtcrime/securesms/net/DeviceTransferBlockingInterceptor.java @@ -55,6 +55,6 @@ public final class DeviceTransferBlockingInterceptor implements Interceptor { public void unblockNetwork() { blockNetworking = false; - AppDependencies.getIncomingMessageObserver(); + AppDependencies.startNetwork(); } } diff --git a/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.java b/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.java deleted file mode 100644 index cd6536f3f5..0000000000 --- a/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.java +++ /dev/null @@ -1,207 +0,0 @@ -package org.thoughtcrime.securesms.net; - -import android.app.Application; - -import androidx.annotation.NonNull; - -import org.signal.core.util.logging.Log; -import org.thoughtcrime.securesms.util.TextSecurePreferences; -import org.whispersystems.signalservice.api.SignalWebSocket; -import org.whispersystems.signalservice.api.util.Preconditions; -import org.whispersystems.signalservice.api.util.SleepTimer; -import org.whispersystems.signalservice.api.websocket.HealthMonitor; -import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState; -import org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection; - -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; - -import io.reactivex.rxjava3.schedulers.Schedulers; - -/** - * Monitors the health of the identified and unidentified WebSockets. If either one appears to be - * unhealthy, will trigger restarting both. - *

- * The monitor is also responsible for sending heartbeats/keep-alive messages to prevent - * timeouts. - */ -public final class SignalWebSocketHealthMonitor implements HealthMonitor { - - private static final String TAG = Log.tag(SignalWebSocketHealthMonitor.class); - - /** - * This is the amount of time in between sent keep alives. Must be greater than {@link SignalWebSocketHealthMonitor#KEEP_ALIVE_TIMEOUT} - */ - private static final long KEEP_ALIVE_SEND_CADENCE = TimeUnit.SECONDS.toMillis(OkHttpWebSocketConnection.KEEPALIVE_FREQUENCY_SECONDS); - - /** - * This is the amount of time we will wait for a response to the keep alive before we consider the websockets dead. - * It is required that this value be less than {@link SignalWebSocketHealthMonitor#KEEP_ALIVE_SEND_CADENCE} - */ - private static final long KEEP_ALIVE_TIMEOUT = TimeUnit.SECONDS.toMillis(20); - - private final Executor executor = Executors.newSingleThreadExecutor(); - - private final Application context; - private SignalWebSocket signalWebSocket; - private final SleepTimer sleepTimer; - - private KeepAliveSender keepAliveSender; - - private final HealthState identified = new HealthState(); - private final HealthState unidentified = new HealthState(); - - public SignalWebSocketHealthMonitor(@NonNull Application context, @NonNull SleepTimer sleepTimer) { - this.context = context; - this.sleepTimer = sleepTimer; - } - - public void monitor(@NonNull SignalWebSocket signalWebSocket) { - executor.execute(() -> { - Preconditions.checkNotNull(signalWebSocket); - Preconditions.checkArgument(this.signalWebSocket == null, "monitor can only be called once"); - - this.signalWebSocket = signalWebSocket; - - //noinspection ResultOfMethodCallIgnored - signalWebSocket.getWebSocketState() - .subscribeOn(Schedulers.computation()) - .observeOn(Schedulers.computation()) - .distinctUntilChanged() - .subscribe(s -> onStateChange(s, identified, true)); - - //noinspection ResultOfMethodCallIgnored - signalWebSocket.getUnidentifiedWebSocketState() - .subscribeOn(Schedulers.computation()) - .observeOn(Schedulers.computation()) - .distinctUntilChanged() - .subscribe(s -> onStateChange(s, unidentified, false)); - }); - } - - private void onStateChange(WebSocketConnectionState connectionState, HealthState healthState, boolean isIdentified) { - executor.execute(() -> { - switch (connectionState) { - case CONNECTED: - if (isIdentified) { - TextSecurePreferences.setUnauthorizedReceived(context, false); - break; - } - case AUTHENTICATION_FAILED: - if (isIdentified) { - TextSecurePreferences.setUnauthorizedReceived(context, true); - break; - } - case FAILED: - break; - } - - healthState.needsKeepAlive = connectionState == WebSocketConnectionState.CONNECTED; - - if (keepAliveSender == null && isKeepAliveNecessary()) { - keepAliveSender = new KeepAliveSender(); - keepAliveSender.start(); - } else if (keepAliveSender != null && !isKeepAliveNecessary()) { - keepAliveSender.shutdown(); - keepAliveSender = null; - } - }); - } - - @Override - public void onKeepAliveResponse(long sentTimestamp, boolean isIdentifiedWebSocket) { - final long keepAliveTime = System.currentTimeMillis(); - executor.execute(() -> { - if (isIdentifiedWebSocket) { - identified.lastKeepAliveReceived = keepAliveTime; - } else { - unidentified.lastKeepAliveReceived = keepAliveTime; - } - }); - } - - @Override - public void onMessageError(int status, boolean isIdentifiedWebSocket) { - executor.execute(() -> { - if (status == 409) { - HealthState healthState = (isIdentifiedWebSocket ? identified : unidentified); - if (healthState.mismatchErrorTracker.addSample(System.currentTimeMillis())) { - Log.w(TAG, "Received too many mismatch device errors, forcing new websockets."); - signalWebSocket.forceNewWebSockets(); - } - } - }); - } - - private boolean isKeepAliveNecessary() { - return identified.needsKeepAlive || unidentified.needsKeepAlive; - } - - private static class HealthState { - private final HttpErrorTracker mismatchErrorTracker = new HttpErrorTracker(5, TimeUnit.MINUTES.toMillis(1)); - - private volatile boolean needsKeepAlive; - private volatile long lastKeepAliveReceived; - } - - /** - * Sends periodic heartbeats/keep-alives over both WebSockets to prevent connection timeouts. If - * either WebSocket fails to get a return heartbeat after {@link SignalWebSocketHealthMonitor#KEEP_ALIVE_TIMEOUT} seconds, both are forced to be recreated. - */ - private class KeepAliveSender extends Thread { - - private volatile boolean shouldKeepRunning = true; - - public void run() { - Log.d(TAG, "[KeepAliveSender] started"); - identified.lastKeepAliveReceived = System.currentTimeMillis(); - unidentified.lastKeepAliveReceived = System.currentTimeMillis(); - - long keepAliveSendTime = System.currentTimeMillis(); - while (shouldKeepRunning && isKeepAliveNecessary()) { - try { - long nextKeepAliveSendTime = (keepAliveSendTime + KEEP_ALIVE_SEND_CADENCE); - sleepUntil(nextKeepAliveSendTime); - - if (shouldKeepRunning && isKeepAliveNecessary()) { - keepAliveSendTime = System.currentTimeMillis(); - signalWebSocket.sendKeepAlive(); - } - - final long responseRequiredTime = keepAliveSendTime + KEEP_ALIVE_TIMEOUT; - sleepUntil(responseRequiredTime); - - if (shouldKeepRunning && isKeepAliveNecessary()) { - if (identified.lastKeepAliveReceived < keepAliveSendTime || unidentified.lastKeepAliveReceived < keepAliveSendTime) { - Log.w(TAG, "Missed keep alives, identified last: " + identified.lastKeepAliveReceived + - " unidentified last: " + unidentified.lastKeepAliveReceived + - " needed by: " + responseRequiredTime); - signalWebSocket.forceNewWebSockets(); - } - } - } catch (Throwable e) { - Log.w(TAG, e); - } - } - Log.d(TAG, "[KeepAliveSender] ended"); - } - - private void sleepUntil(long timeMs) { - while (System.currentTimeMillis() < timeMs) { - long waitTime = timeMs - System.currentTimeMillis(); - if (waitTime > 0) { - try { - sleepTimer.sleep(waitTime); - } catch (InterruptedException e) { - Log.w(TAG, e); - } - } - } - } - - public void shutdown() { - shouldKeepRunning = false; - } - } -} diff --git a/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.kt b/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.kt new file mode 100644 index 0000000000..664682ae73 --- /dev/null +++ b/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.kt @@ -0,0 +1,172 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.thoughtcrime.securesms.net + +import io.reactivex.rxjava3.kotlin.subscribeBy +import io.reactivex.rxjava3.schedulers.Schedulers +import org.signal.core.util.logging.Log +import org.thoughtcrime.securesms.dependencies.AppDependencies +import org.thoughtcrime.securesms.util.TextSecurePreferences +import org.whispersystems.signalservice.api.util.SleepTimer +import org.whispersystems.signalservice.api.websocket.HealthMonitor +import org.whispersystems.signalservice.api.websocket.SignalWebSocket +import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState +import org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection +import java.util.concurrent.Executor +import java.util.concurrent.Executors +import kotlin.concurrent.Volatile +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +class SignalWebSocketHealthMonitor( + private val sleepTimer: SleepTimer +) : HealthMonitor { + + companion object { + private val TAG = Log.tag(SignalWebSocketHealthMonitor::class) + + /** + * This is the amount of time in between sent keep alives. Must be greater than [KEEP_ALIVE_TIMEOUT] + */ + private val KEEP_ALIVE_SEND_CADENCE: Duration = OkHttpWebSocketConnection.KEEPALIVE_FREQUENCY_SECONDS.seconds + + /** + * This is the amount of time we will wait for a response to the keep alive before we consider the websockets dead. + * It is required that this value be less than [KEEP_ALIVE_SEND_CADENCE] + */ + private val KEEP_ALIVE_TIMEOUT: Duration = 20.seconds + } + + private val executor: Executor = Executors.newSingleThreadExecutor() + + private var webSocket: SignalWebSocket? = null + + private var keepAliveSender: KeepAliveSender? = null + private var needsKeepAlive = false + private var lastKeepAliveReceived: Duration = 0.seconds + + @Suppress("CheckResult") + fun monitor(webSocket: SignalWebSocket) { + executor.execute { + check(this.webSocket == null) + + this.webSocket = webSocket + + webSocket + .state + .subscribeOn(Schedulers.computation()) + .observeOn(Schedulers.computation()) + .distinctUntilChanged() + .subscribeBy { onStateChanged(it) } + + webSocket.keepAliveChangedListener = this::updateKeepAliveSenderStatus + } + } + + private fun onStateChanged(connectionState: WebSocketConnectionState) { + executor.execute { + when (connectionState) { + WebSocketConnectionState.CONNECTED -> { + if (webSocket is SignalWebSocket.AuthenticatedWebSocket) { + TextSecurePreferences.setUnauthorizedReceived(AppDependencies.application, false) + } + } + WebSocketConnectionState.AUTHENTICATION_FAILED -> { + if (webSocket is SignalWebSocket.AuthenticatedWebSocket) { + TextSecurePreferences.setUnauthorizedReceived(AppDependencies.application, true) + } + } + else -> Unit + } + + needsKeepAlive = connectionState == WebSocketConnectionState.CONNECTED + + updateKeepAliveSenderStatus() + } + } + + override fun onKeepAliveResponse(sentTimestamp: Long, isIdentifiedWebSocket: Boolean) { + val keepAliveTime = System.currentTimeMillis().milliseconds + executor.execute { + lastKeepAliveReceived = keepAliveTime + } + } + + override fun onMessageError(status: Int, isIdentifiedWebSocket: Boolean) = Unit + + private fun updateKeepAliveSenderStatus() { + if (keepAliveSender == null && sendKeepAlives()) { + keepAliveSender = KeepAliveSender() + keepAliveSender!!.start() + } else if (keepAliveSender != null && !sendKeepAlives()) { + keepAliveSender!!.shutdown() + keepAliveSender = null + } + } + + private fun sendKeepAlives(): Boolean { + return needsKeepAlive && webSocket?.shouldSendKeepAlives == true + } + + /** + * Sends periodic heartbeats/keep-alives over the WebSocket to prevent connection timeouts. If + * the WebSocket fails to get a return heartbeat after [KEEP_ALIVE_TIMEOUT] seconds, it is forced to be recreated. + */ + private inner class KeepAliveSender : Thread() { + + @Volatile + private var shouldKeepRunning = true + + override fun run() { + Log.d(TAG, "[KeepAliveSender($id)] started") + lastKeepAliveReceived = System.currentTimeMillis().milliseconds + + var keepAliveSendTime = System.currentTimeMillis().milliseconds + while (shouldKeepRunning && sendKeepAlives()) { + try { + val nextKeepAliveSendTime: Duration = keepAliveSendTime + KEEP_ALIVE_SEND_CADENCE + sleepUntil(nextKeepAliveSendTime) + + if (shouldKeepRunning && sendKeepAlives()) { + keepAliveSendTime = System.currentTimeMillis().milliseconds + webSocket?.sendKeepAlive() + } + + val responseRequiredTime: Duration = keepAliveSendTime + KEEP_ALIVE_TIMEOUT + sleepUntil(responseRequiredTime) + + if (shouldKeepRunning && sendKeepAlives()) { + if (lastKeepAliveReceived < keepAliveSendTime) { + Log.w(TAG, "Missed keep alive, last: ${lastKeepAliveReceived.inWholeMilliseconds} needed by: ${responseRequiredTime.inWholeMilliseconds}") + webSocket?.forceNewWebSocket() + } + } + } catch (e: Throwable) { + Log.w(TAG, e) + } + } + Log.d(TAG, "[KeepAliveSender($id)] ended") + } + + fun sleepUntil(time: Duration) { + while (System.currentTimeMillis().milliseconds < time) { + val waitTime = time - System.currentTimeMillis().milliseconds + if (waitTime.isPositive()) { + try { + sleepTimer.sleep(waitTime.inWholeMilliseconds) + } catch (e: InterruptedException) { + Log.w(TAG, e) + } + } + } + } + + fun shutdown() { + shouldKeepRunning = false + } + } +} diff --git a/app/src/main/java/org/thoughtcrime/securesms/registration/data/RegistrationRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/registration/data/RegistrationRepository.kt index d3e4a46291..a8d67256d2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registration/data/RegistrationRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/registration/data/RegistrationRepository.kt @@ -222,7 +222,7 @@ object RegistrationRepository { SvrRepository.onRegistrationComplete(masterKey, data.pin, hasPin, data.reglockEnabled) AppDependencies.resetNetwork() - AppDependencies.incomingMessageObserver + AppDependencies.startNetwork() PreKeysSyncJob.enqueue() val jobManager = AppDependencies.jobManager diff --git a/app/src/main/java/org/thoughtcrime/securesms/registrationv3/data/RegistrationRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/registrationv3/data/RegistrationRepository.kt index fa58d969eb..accbd065f2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/registrationv3/data/RegistrationRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/registrationv3/data/RegistrationRepository.kt @@ -217,7 +217,7 @@ object RegistrationRepository { SvrRepository.onRegistrationComplete(masterKey, data.pin, hasPin, data.reglockEnabled) AppDependencies.resetNetwork() - AppDependencies.incomingMessageObserver + AppDependencies.startNetwork() PreKeysSyncJob.enqueue() val jobManager = AppDependencies.jobManager diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/SignalProxyUtil.java b/app/src/main/java/org/thoughtcrime/securesms/util/SignalProxyUtil.java index 4ea65f0594..51bb39af85 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/SignalProxyUtil.java +++ b/app/src/main/java/org/thoughtcrime/securesms/util/SignalProxyUtil.java @@ -37,12 +37,12 @@ public final class SignalProxyUtil { private SignalProxyUtil() {} public static void startListeningToWebsocket() { - if (SignalStore.proxy().isProxyEnabled() && AppDependencies.getSignalWebSocket().getWebSocketState().firstOrError().blockingGet().isFailure()) { + if (SignalStore.proxy().isProxyEnabled() && AppDependencies.getAuthWebSocket().getState().firstOrError().blockingGet().isFailure()) { Log.w(TAG, "Proxy is in a failed state. Restarting."); AppDependencies.resetNetwork(); } - AppDependencies.getIncomingMessageObserver(); + AppDependencies.startNetwork(); } /** @@ -88,8 +88,8 @@ public final class SignalProxyUtil { return testWebsocketConnectionUnregistered(timeout); } - return AppDependencies.getSignalWebSocket() - .getWebSocketState() + return AppDependencies.getAuthWebSocket() + .getState() .subscribeOn(Schedulers.trampoline()) .observeOn(Schedulers.trampoline()) .timeout(timeout, TimeUnit.MILLISECONDS) diff --git a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt index 78c32085aa..9771245c94 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt +++ b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt @@ -36,7 +36,6 @@ import org.whispersystems.signalservice.api.SignalServiceAccountManager import org.whispersystems.signalservice.api.SignalServiceDataStore import org.whispersystems.signalservice.api.SignalServiceMessageReceiver import org.whispersystems.signalservice.api.SignalServiceMessageSender -import org.whispersystems.signalservice.api.SignalWebSocket import org.whispersystems.signalservice.api.archive.ArchiveApi import org.whispersystems.signalservice.api.attachment.AttachmentApi import org.whispersystems.signalservice.api.groupsv2.GroupsV2Operations @@ -47,6 +46,7 @@ import org.whispersystems.signalservice.api.services.CallLinksService import org.whispersystems.signalservice.api.services.DonationsService import org.whispersystems.signalservice.api.services.ProfileService import org.whispersystems.signalservice.api.storage.StorageServiceApi +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration import org.whispersystems.signalservice.internal.push.PushServiceSocket import java.util.function.Supplier @@ -64,7 +64,12 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { return mockk(relaxed = true) } - override fun provideSignalServiceMessageSender(signalWebSocket: SignalWebSocket, protocolStore: SignalServiceDataStore, pushServiceSocket: PushServiceSocket): SignalServiceMessageSender { + override fun provideSignalServiceMessageSender( + authWebSocket: SignalWebSocket.AuthenticatedWebSocket, + unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket, + protocolStore: SignalServiceDataStore, + pushServiceSocket: PushServiceSocket + ): SignalServiceMessageSender { return mockk(relaxed = true) } @@ -100,7 +105,7 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { return mockk(relaxed = true) } - override fun provideIncomingMessageObserver(signalWebSocket: SignalWebSocket): IncomingMessageObserver { + override fun provideIncomingMessageObserver(webSocket: SignalWebSocket.AuthenticatedWebSocket): IncomingMessageObserver { return mockk(relaxed = true) } @@ -156,10 +161,6 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { return mockk(relaxed = true) } - override fun provideSignalWebSocket(signalServiceConfigurationSupplier: Supplier, libSignalNetworkSupplier: Supplier): SignalWebSocket { - return mockk(relaxed = true) - } - override fun provideProtocolStore(): SignalServiceDataStoreImpl { return mockk(relaxed = true) } @@ -184,7 +185,12 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { return mockk(relaxed = true) } - override fun provideProfileService(profileOperations: ClientZkProfileOperations, signalServiceMessageReceiver: SignalServiceMessageReceiver, signalWebSocket: SignalWebSocket): ProfileService { + override fun provideProfileService( + profileOperations: ClientZkProfileOperations, + signalServiceMessageReceiver: SignalServiceMessageReceiver, + authWebSocket: SignalWebSocket.AuthenticatedWebSocket, + unauthWebSocket: SignalWebSocket.UnauthenticatedWebSocket + ): ProfileService { return mockk(relaxed = true) } @@ -216,7 +222,7 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { return mockk(relaxed = true) } - override fun provideAttachmentApi(signalWebSocket: SignalWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi { + override fun provideAttachmentApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi { return mockk(relaxed = true) } @@ -231,4 +237,12 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { override fun provideStorageServiceApi(pushServiceSocket: PushServiceSocket): StorageServiceApi { return mockk(relaxed = true) } + + override fun provideAuthWebSocket(signalServiceConfigurationSupplier: Supplier, libSignalNetworkSupplier: Supplier): SignalWebSocket.AuthenticatedWebSocket { + return mockk(relaxed = true) + } + + override fun provideUnauthWebSocket(signalServiceConfigurationSupplier: Supplier, libSignalNetworkSupplier: Supplier): SignalWebSocket.UnauthenticatedWebSocket { + return mockk(relaxed = true) + } } diff --git a/libsignal-service/build.gradle.kts b/libsignal-service/build.gradle.kts index 28a91b1886..51ba2a5805 100644 --- a/libsignal-service/build.gradle.kts +++ b/libsignal-service/build.gradle.kts @@ -96,6 +96,7 @@ dependencies { implementation(libs.google.jsr305) api(libs.rxjava3.rxjava) + implementation(libs.rxjava3.rxkotlin) implementation(libs.kotlin.stdlib.jdk8) implementation(libs.kotlinx.coroutines.core) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt index 8cd49a6820..bc6249be1b 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt @@ -7,6 +7,7 @@ package org.whispersystems.signalservice.api import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.internal.util.JsonUtil import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage import org.whispersystems.signalservice.internal.websocket.WebsocketResponse @@ -147,6 +148,7 @@ sealed class NetworkResult( * * Useful for bridging to Java, where you may want to use try-catch. */ + @Throws(NonSuccessfulResponseCodeException::class, IOException::class, Throwable::class) fun successOrThrow(): T { when (this) { is Success -> return result diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java index 7ad3f42ec6..c6b47af941 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java @@ -85,6 +85,7 @@ import org.whispersystems.signalservice.api.util.Preconditions; import org.whispersystems.signalservice.api.util.Uint64RangeException; import org.whispersystems.signalservice.api.util.Uint64Util; import org.whispersystems.signalservice.api.util.UuidUtil; +import org.whispersystems.signalservice.api.websocket.SignalWebSocket; import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException; import org.whispersystems.signalservice.internal.ServiceResponse; import org.whispersystems.signalservice.internal.crypto.AttachmentDigest; @@ -131,7 +132,6 @@ import org.whispersystems.signalservice.internal.util.Util; import java.io.IOException; import java.io.InputStream; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -143,7 +143,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.stream.Collectors; import javax.annotation.Nonnull; @@ -170,7 +169,6 @@ public class SignalServiceMessageSender { private static final int RETRY_COUNT = 4; private final PushServiceSocket socket; - private final SignalWebSocket webSocket; private final SignalServiceAccountDataStore aciStore; private final SignalSessionLock sessionLock; private final SignalServiceAddress localAddress; @@ -182,14 +180,14 @@ public class SignalServiceMessageSender { private final AttachmentService attachmentService; private final MessagingService messagingService; - private final ExecutorService executor; private final Scheduler scheduler; private final long maxEnvelopeSize; public SignalServiceMessageSender(PushServiceSocket pushServiceSocket, SignalServiceDataStore store, SignalSessionLock sessionLock, - SignalWebSocket signalWebSocket, + SignalWebSocket.AuthenticatedWebSocket authWebSocket, + SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket, Optional eventListener, ExecutorService executor, long maxEnvelopeSize) @@ -197,16 +195,14 @@ public class SignalServiceMessageSender { CredentialsProvider credentialsProvider = pushServiceSocket.getCredentialsProvider(); this.socket = pushServiceSocket; - this.webSocket = signalWebSocket; this.aciStore = store.aci(); this.sessionLock = sessionLock; this.localAddress = new SignalServiceAddress(credentialsProvider.getAci(), credentialsProvider.getE164()); this.localDeviceId = credentialsProvider.getDeviceId(); this.localPni = credentialsProvider.getPni(); - this.attachmentService = new AttachmentService(signalWebSocket); - this.messagingService = new MessagingService(signalWebSocket); + this.attachmentService = new AttachmentService(authWebSocket); + this.messagingService = new MessagingService(authWebSocket, unauthWebSocket); this.eventListener = eventListener; - this.executor = executor != null ? executor : Executors.newSingleThreadExecutor(); this.maxEnvelopeSize = maxEnvelopeSize; this.localPniIdentity = store.pni().getIdentityKeyPair(); this.scheduler = Schedulers.from(executor, false, false); @@ -840,7 +836,7 @@ public class SignalServiceMessageSender { } Log.w(TAG, "Failed to retrieve attachment upload attributes using pipe. Falling back..."); } - + if (v4UploadAttributes == null) { Log.d(TAG, "Not using pipe to retrieve attachment upload attributes..."); v4UploadAttributes = socket.getAttachmentV4UploadAttributes(); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalWebSocket.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalWebSocket.java deleted file mode 100644 index a114672775..0000000000 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalWebSocket.java +++ /dev/null @@ -1,372 +0,0 @@ -package org.whispersystems.signalservice.api; - -import org.signal.libsignal.protocol.logging.Log; -import org.whispersystems.signalservice.api.crypto.SealedSenderAccess; -import org.whispersystems.signalservice.api.messages.EnvelopeResponse; -import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState; -import org.whispersystems.signalservice.api.websocket.WebSocketFactory; -import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException; -import org.whispersystems.signalservice.internal.push.Envelope; -import org.whispersystems.signalservice.internal.websocket.WebSocketConnection; -import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage; -import org.whispersystems.signalservice.internal.websocket.WebSocketResponseMessage; -import org.whispersystems.signalservice.internal.websocket.WebsocketResponse; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.TimeoutException; - -import javax.annotation.Nullable; - -import io.reactivex.rxjava3.core.Observable; -import io.reactivex.rxjava3.core.Single; -import io.reactivex.rxjava3.disposables.CompositeDisposable; -import io.reactivex.rxjava3.disposables.Disposable; -import io.reactivex.rxjava3.schedulers.Schedulers; -import io.reactivex.rxjava3.subjects.BehaviorSubject; - -/** - * Provide a general interface to the WebSocket for making requests and reading messages sent by the server. - * Where appropriate, it will handle retrying failed unidentified requests on the regular WebSocket. - */ -public final class SignalWebSocket { - - private static final String TAG = SignalWebSocket.class.getSimpleName(); - - private static final String SERVER_DELIVERED_TIMESTAMP_HEADER = "X-Signal-Timestamp"; - - private final WebSocketFactory webSocketFactory; - - private WebSocketConnection webSocket; - private final BehaviorSubject webSocketState; - private CompositeDisposable webSocketStateDisposable; - - private WebSocketConnection unidentifiedWebSocket; - private final BehaviorSubject unidentifiedWebSocketState; - private CompositeDisposable unidentifiedWebSocketStateDisposable; - - private boolean canConnect; - - public SignalWebSocket(WebSocketFactory webSocketFactory) { - this.webSocketFactory = webSocketFactory; - this.webSocketState = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED); - this.unidentifiedWebSocketState = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED); - this.webSocketStateDisposable = new CompositeDisposable(); - this.unidentifiedWebSocketStateDisposable = new CompositeDisposable(); - } - - /** - * Get an observable stream of the identified WebSocket state. This observable is valid for the lifetime of - * the instance, and will update as WebSocketConnections are remade. - */ - public Observable getWebSocketState() { - return webSocketState; - } - - /** - * Get an observable stream of the unidentified WebSocket state. This observable is valid for the lifetime of - * the instance, and will update as WebSocketConnections are remade. - */ - public Observable getUnidentifiedWebSocketState() { - return unidentifiedWebSocketState; - } - - /** - * Indicate that WebSocketConnections can now be made and attempt to connect both of them. - */ - public synchronized void connect() { - canConnect = true; - try { - getWebSocket(); - getUnidentifiedWebSocket(); - } catch (WebSocketUnavailableException e) { - throw new AssertionError(e); - } - } - - /** - * Indicate that WebSocketConnections can no longer be made and disconnect both of them. - */ - public synchronized void disconnect() { - canConnect = false; - disconnectIdentified(); - disconnectUnidentified(); - } - - /** - * Indicate that the current WebSocket instances need to be destroyed and new ones should be created the - * next time a connection is required. Intended to be used by the health monitor to cycle a WebSocket. - */ - public synchronized void forceNewWebSockets() { - Log.i(TAG, "Forcing new WebSockets " + - " identified: " + (webSocket != null ? webSocket.getName() : "[null]") + - " unidentified: " + (unidentifiedWebSocket != null ? unidentifiedWebSocket.getName() : "[null]") + - " canConnect: " + canConnect); - - disconnectIdentified(); - disconnectUnidentified(); - } - - private void disconnectIdentified() { - if (webSocket != null) { - webSocketStateDisposable.dispose(); - - webSocket.disconnect(); - webSocket = null; - - //noinspection ConstantConditions - if (!webSocketState.getValue().isFailure()) { - webSocketState.onNext(WebSocketConnectionState.DISCONNECTED); - } - } - } - - private void disconnectUnidentified() { - if (unidentifiedWebSocket != null) { - unidentifiedWebSocketStateDisposable.dispose(); - - unidentifiedWebSocket.disconnect(); - unidentifiedWebSocket = null; - - //noinspection ConstantConditions - if (!unidentifiedWebSocketState.getValue().isFailure()) { - unidentifiedWebSocketState.onNext(WebSocketConnectionState.DISCONNECTED); - } - } - } - - private synchronized WebSocketConnection getWebSocket() throws WebSocketUnavailableException { - if (!canConnect) { - throw new WebSocketUnavailableException(); - } - - if (webSocket == null || webSocket.isDead()) { - webSocketStateDisposable.dispose(); - - webSocket = webSocketFactory.createWebSocket(); - webSocketStateDisposable = new CompositeDisposable(); - - Disposable state = webSocket.connect() - .subscribeOn(Schedulers.computation()) - .observeOn(Schedulers.computation()) - .subscribe(webSocketState::onNext); - webSocketStateDisposable.add(state); - } - return webSocket; - } - - private synchronized WebSocketConnection getUnidentifiedWebSocket() throws WebSocketUnavailableException { - if (!canConnect) { - throw new WebSocketUnavailableException(); - } - - if (unidentifiedWebSocket == null || unidentifiedWebSocket.isDead()) { - unidentifiedWebSocketStateDisposable.dispose(); - - unidentifiedWebSocket = webSocketFactory.createUnidentifiedWebSocket(); - unidentifiedWebSocketStateDisposable = new CompositeDisposable(); - - Disposable state = unidentifiedWebSocket.connect() - .subscribeOn(Schedulers.computation()) - .observeOn(Schedulers.computation()) - .subscribe(unidentifiedWebSocketState::onNext); - unidentifiedWebSocketStateDisposable.add(state); - } - return unidentifiedWebSocket; - } - - /** - * Send keep-alive messages over both WebSocketConnections. - */ - public synchronized void sendKeepAlive() throws IOException { - if (canConnect) { - try { - getWebSocket().sendKeepAlive(); - getUnidentifiedWebSocket().sendKeepAlive(); - } catch (WebSocketUnavailableException e) { - throw new AssertionError(e); - } - } - } - - public Single request(WebSocketRequestMessage requestMessage) { - try { - return getWebSocket().sendRequest(requestMessage); - } catch (IOException e) { - return Single.error(e); - } - } - - public Single request(WebSocketRequestMessage requestMessage, @Nullable SealedSenderAccess sealedSenderAccess) { - if (sealedSenderAccess != null) { - List headers = new ArrayList<>(requestMessage.headers); - headers.add(sealedSenderAccess.getHeader()); - - WebSocketRequestMessage message = requestMessage.newBuilder() - .headers(headers) - .build(); - try { - return getUnidentifiedWebSocket().sendRequest(message) - .flatMap(r -> { - if (r.getStatus() == 401) { - return request(requestMessage, sealedSenderAccess.switchToFallback()); - } - return Single.just(r); - }); - } catch (IOException e) { - return Single.error(e); - } - } else { - return request(requestMessage); - } - } - - /** - * The reads a batch of messages off of the websocket. - * - * Rather than just provide you the batch as a return value, it will invoke the provided callback with the - * batch as an argument. If you are able to successfully process them, this method will then ack all of the - * messages so that they won't be re-delivered in the future. - * - * The return value of this method is a boolean indicating whether or not there are more messages in the - * queue to be read (true if there's still more, or false if you've drained everything). - * - * However, this return value is only really useful the first time you read from the websocket. That's because - * the websocket will only ever let you know if it's drained *once* for any given connection. So if this method - * returns false, a subsequent call while using the same websocket connection will simply block until we either - * get a new message or hit the timeout. - * - * Concerning the requested batch size, it's worth noting that this is simply an upper bound. This method will - * not wait extra time until the batch has "filled up". Instead, it will wait for a single message, and then - * take any extra messages that are also available up until you've hit your batch size. - */ - @SuppressWarnings("DuplicateThrows") - public boolean readMessageBatch(long timeout, int batchSize, MessageReceivedCallback callback) - throws TimeoutException, WebSocketUnavailableException, IOException - { - List responses = new ArrayList<>(); - boolean hitEndOfQueue = false; - - Optional firstEnvelope = waitForSingleMessage(timeout); - - if (firstEnvelope.isPresent()) { - responses.add(firstEnvelope.get()); - } else { - hitEndOfQueue = true; - } - - if (!hitEndOfQueue) { - for (int i = 1; i < batchSize; i++) { - Optional request = getWebSocket().readRequestIfAvailable(); - - if (request.isPresent()) { - if (isSignalServiceEnvelope(request.get())) { - responses.add(requestToEnvelopeResponse(request.get())); - } else if (isSocketEmptyRequest(request.get())) { - hitEndOfQueue = true; - break; - } - } else { - break; - } - } - } - - if (responses.size() > 0) { - callback.onMessageBatch(responses); - } - - return !hitEndOfQueue; - } - - public void sendAck(EnvelopeResponse response) throws IOException { - getWebSocket().sendResponse(createWebSocketResponse(response.getWebsocketRequest())); - } - - @SuppressWarnings("DuplicateThrows") - private Optional waitForSingleMessage(long timeout) - throws TimeoutException, WebSocketUnavailableException, IOException - { - while (true) { - WebSocketRequestMessage request = getWebSocket().readRequest(timeout); - - if (isSignalServiceEnvelope(request)) { - return Optional.of(requestToEnvelopeResponse(request)); - } else if (isSocketEmptyRequest(request)) { - return Optional.empty(); - } - } - } - - private static EnvelopeResponse requestToEnvelopeResponse(WebSocketRequestMessage request) - throws IOException - { - Optional timestampHeader = findHeader(request); - long timestamp = 0; - - if (timestampHeader.isPresent()) { - try { - timestamp = Long.parseLong(timestampHeader.get()); - } catch (NumberFormatException e) { - Log.w(TAG, "Failed to parse " + SERVER_DELIVERED_TIMESTAMP_HEADER); - } - } - - Envelope envelope = Envelope.ADAPTER.decode(request.body.toByteArray()); - - return new EnvelopeResponse(envelope, timestamp, request); - } - - private static boolean isSignalServiceEnvelope(WebSocketRequestMessage message) { - return "PUT".equals(message.verb) && "/api/v1/message".equals(message.path); - } - - private static boolean isSocketEmptyRequest(WebSocketRequestMessage message) { - return "PUT".equals(message.verb) && "/api/v1/queue/empty".equals(message.path); - } - - private static WebSocketResponseMessage createWebSocketResponse(WebSocketRequestMessage request) { - if (isSignalServiceEnvelope(request)) { - return new WebSocketResponseMessage.Builder() - .id(request.id) - .status(200) - .message("OK") - .build(); - } else { - return new WebSocketResponseMessage.Builder() - .id(request.id) - .status(400) - .message("Unknown") - .build(); - } - } - - private static Optional findHeader(WebSocketRequestMessage message) { - if (message.headers.isEmpty()) { - return Optional.empty(); - } - - for (String header : message.headers) { - if (header.startsWith(SERVER_DELIVERED_TIMESTAMP_HEADER)) { - String[] split = header.split(":"); - if (split.length == 2 && split[0].trim().toLowerCase().equals(SERVER_DELIVERED_TIMESTAMP_HEADER.toLowerCase())) { - return Optional.of(split[1].trim()); - } - } - } - - return Optional.empty(); - } - - /** - * For receiving a callback when a new message has been - * received. - */ - public interface MessageReceivedCallback { - - /** Called with the batch of envelopes. You are responsible for sending acks. **/ - void onMessageBatch(List envelopeResponses); - } -} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentApi.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentApi.kt index 379f33dde4..1838854755 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentApi.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/attachment/AttachmentApi.kt @@ -6,10 +6,10 @@ package org.whispersystems.signalservice.api.attachment import org.whispersystems.signalservice.api.NetworkResult -import org.whispersystems.signalservice.api.SignalWebSocket import org.whispersystems.signalservice.api.crypto.AttachmentCipherStreamUtil import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentRemoteId import org.whispersystems.signalservice.api.messages.SignalServiceAttachmentStream +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.internal.crypto.PaddingInputStream import org.whispersystems.signalservice.internal.push.AttachmentUploadForm import org.whispersystems.signalservice.internal.push.PushAttachmentData @@ -25,13 +25,13 @@ import kotlin.jvm.optionals.getOrNull * Class to interact with various attachment-related endpoints. */ class AttachmentApi( - private val signalWebSocket: SignalWebSocket, + private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket, private val pushServiceSocket: PushServiceSocket ) { companion object { @JvmStatic - fun create(signalWebSocket: SignalWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi { - return AttachmentApi(signalWebSocket, pushServiceSocket) + fun create(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi { + return AttachmentApi(authWebSocket, pushServiceSocket) } } @@ -46,7 +46,7 @@ class AttachmentApi( ) return NetworkResult - .fromWebSocketRequest(signalWebSocket, request, AttachmentUploadForm::class) + .fromWebSocketRequest(authWebSocket, request, AttachmentUploadForm::class) .fallbackToFetch( unless = { it is NetworkResult.StatusCodeError && it.code == 209 }, fallback = { pushServiceSocket.attachmentV4UploadAttributes } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/AttachmentService.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/AttachmentService.kt index 23d09a394a..0f276a3133 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/AttachmentService.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/AttachmentService.kt @@ -1,7 +1,7 @@ package org.whispersystems.signalservice.api.services import io.reactivex.rxjava3.core.Single -import org.whispersystems.signalservice.api.SignalWebSocket +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.internal.ServiceResponse import org.whispersystems.signalservice.internal.ServiceResponseProcessor import org.whispersystems.signalservice.internal.push.AttachmentUploadForm @@ -15,7 +15,7 @@ import java.security.SecureRandom * * Note: To be expanded to have REST fallback and other attachment related operations. */ -class AttachmentService(private val signalWebSocket: SignalWebSocket) { +class AttachmentService(private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket) { fun getAttachmentV4UploadAttributes(): Single> { val requestMessage = WebSocketRequestMessage( id = SecureRandom().nextLong(), @@ -23,7 +23,7 @@ class AttachmentService(private val signalWebSocket: SignalWebSocket) { path = "/v4/attachments/form/upload" ) - return signalWebSocket.request(requestMessage) + return authWebSocket.request(requestMessage) .map { response: WebsocketResponse? -> DefaultResponseMapper.getDefault(AttachmentUploadForm::class.java).map(response) } .onErrorReturn { throwable: Throwable? -> ServiceResponse.forUnknownError(throwable) } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/MessagingService.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/MessagingService.java index 2f4ca3d6cd..5abb4502f1 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/MessagingService.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/MessagingService.java @@ -1,9 +1,9 @@ package org.whispersystems.signalservice.api.services; -import org.whispersystems.signalservice.api.SignalWebSocket; import org.whispersystems.signalservice.api.crypto.SealedSenderAccess; import org.whispersystems.signalservice.api.push.exceptions.NotFoundException; import org.whispersystems.signalservice.api.push.exceptions.UnregisteredUserException; +import org.whispersystems.signalservice.api.websocket.SignalWebSocket; import org.whispersystems.signalservice.internal.ServiceResponse; import org.whispersystems.signalservice.internal.ServiceResponseProcessor; import org.whispersystems.signalservice.internal.push.GroupMismatchedDevices; @@ -37,10 +37,12 @@ import okio.ByteString; * Note: To be expanded to have REST fallback and other messaging related operations. */ public class MessagingService { - private final SignalWebSocket signalWebSocket; + private final SignalWebSocket.AuthenticatedWebSocket authWebSocket; + private final SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket; - public MessagingService(SignalWebSocket signalWebSocket) { - this.signalWebSocket = signalWebSocket; + public MessagingService(SignalWebSocket.AuthenticatedWebSocket authWebSocket, SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket) { + this.authWebSocket = authWebSocket; + this.unauthWebSocket = unauthWebSocket; } public Single> send(OutgoingPushMessageList list, @@ -69,9 +71,22 @@ public class MessagingService { .withCustomError(404, (status, body, getHeader) -> new UnregisteredUserException(list.getDestination(), new NotFoundException("not found"))) .build(); - return signalWebSocket.request(requestMessage, sealedSenderAccess) + if (sealedSenderAccess == null) { + return authWebSocket.request(requestMessage) .map(responseMapper::map) .onErrorReturn(ServiceResponse::forUnknownError); + } else { + return unauthWebSocket.request(requestMessage, sealedSenderAccess) + .flatMap(response -> { + if (response.getStatus() == 401) { + return authWebSocket.request(requestMessage); + } else { + return Single.just(response); + } + }) + .map(responseMapper::map) + .onErrorReturn(ServiceResponse::forUnknownError); + } } public Single> sendToGroup(byte[] body, @Nonnull SealedSenderAccess sealedSenderAccess, long timestamp, boolean online, boolean urgent, boolean story) { @@ -90,7 +105,7 @@ public class MessagingService { .body(ByteString.of(body)) .build(); - return signalWebSocket.request(requestMessage) + return unauthWebSocket.request(requestMessage) .map(DefaultResponseMapper.extend(SendGroupMessageResponse.class) .withCustomError(401, (status, errorBody, getHeader) -> new InvalidUnidentifiedAccessHeaderException()) .withCustomError(404, (status, errorBody, getHeader) -> new NotFoundException("At least one unregistered user in message send.")) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java index 64e18d595f..efd696082d 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/services/ProfileService.java @@ -11,9 +11,7 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest; import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext; import org.signal.libsignal.zkgroup.profiles.ProfileKeyVersion; import org.whispersystems.signalservice.api.SignalServiceMessageReceiver; -import org.whispersystems.signalservice.api.SignalWebSocket; import org.whispersystems.signalservice.api.crypto.SealedSenderAccess; -import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess; import org.whispersystems.signalservice.api.profiles.ProfileAndCredential; import org.whispersystems.signalservice.api.profiles.SignalServiceProfile; import org.whispersystems.signalservice.api.push.ServiceId; @@ -21,6 +19,7 @@ import org.whispersystems.signalservice.api.push.ServiceId.ACI; import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.exceptions.AuthorizationFailedException; import org.whispersystems.signalservice.api.push.exceptions.MalformedResponseException; +import org.whispersystems.signalservice.api.websocket.SignalWebSocket; import org.whispersystems.signalservice.internal.ServiceResponse; import org.whispersystems.signalservice.internal.ServiceResponseProcessor; import org.whispersystems.signalservice.internal.push.IdentityCheckRequest; @@ -59,17 +58,20 @@ public final class ProfileService { private static final String TAG = ProfileService.class.getSimpleName(); - private final ClientZkProfileOperations clientZkProfileOperations; - private final SignalServiceMessageReceiver receiver; - private final SignalWebSocket signalWebSocket; + private final ClientZkProfileOperations clientZkProfileOperations; + private final SignalServiceMessageReceiver receiver; + private final SignalWebSocket.AuthenticatedWebSocket authWebSocket; + private final SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket; public ProfileService(ClientZkProfileOperations clientZkProfileOperations, SignalServiceMessageReceiver receiver, - SignalWebSocket signalWebSocket) + SignalWebSocket.AuthenticatedWebSocket authWebSocket, + SignalWebSocket.UnauthenticatedWebSocket unauthWebSocket) { this.clientZkProfileOperations = clientZkProfileOperations; this.receiver = receiver; - this.signalWebSocket = signalWebSocket; + this.authWebSocket = authWebSocket; + this.unauthWebSocket = unauthWebSocket; } public Single> getProfile(@Nonnull SignalServiceAddress address, @@ -118,10 +120,24 @@ public final class ProfileService { .withResponseMapper(new ProfileResponseMapper(requestType, requestContext)) .build(); - return signalWebSocket.request(requestMessage, sealedSenderAccess) + if (sealedSenderAccess == null) { + return authWebSocket.request(requestMessage) .map(responseMapper::map) - .onErrorResumeNext(t -> getProfileRestFallback(address, profileKey, sealedSenderAccess, requestType, locale)) + .onErrorResumeNext(t -> getProfileRestFallback(address, profileKey, null, requestType, locale)) .onErrorReturn(ServiceResponse::forUnknownError); + } else { + return unauthWebSocket.request(requestMessage, sealedSenderAccess) + .flatMap(response -> { + if (response.getStatus() == 401) { + return authWebSocket.request(requestMessage); + } else { + return Single.just(response); + } + }) + .map(responseMapper::map) + .onErrorResumeNext(t -> getProfileRestFallback(address, profileKey, sealedSenderAccess, requestType, locale)) + .onErrorReturn(ServiceResponse::forUnknownError); + } } public @NonNull Single> performIdentityCheck(@Nonnull Map serviceIdIdentityKeyMap) { @@ -141,7 +157,7 @@ public final class ProfileService { ResponseMapper responseMapper = DefaultResponseMapper.getDefault(IdentityCheckResponse.class); - return signalWebSocket.request(builder.build(), SealedSenderAccess.NONE) + return unauthWebSocket.request(builder.build()) .map(responseMapper::map) .onErrorResumeNext(t -> performIdentityCheckRestFallback(request, responseMapper)) .onErrorReturn(ServiceResponse::forUnknownError); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt new file mode 100644 index 0000000000..f78837cc1d --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt @@ -0,0 +1,307 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.api.websocket + +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.disposables.CompositeDisposable +import io.reactivex.rxjava3.kotlin.addTo +import io.reactivex.rxjava3.kotlin.subscribeBy +import io.reactivex.rxjava3.schedulers.Schedulers +import io.reactivex.rxjava3.subjects.BehaviorSubject +import org.signal.core.util.logging.Log +import org.signal.core.util.orNull +import org.whispersystems.signalservice.api.crypto.SealedSenderAccess +import org.whispersystems.signalservice.api.messages.EnvelopeResponse +import org.whispersystems.signalservice.internal.push.Envelope +import org.whispersystems.signalservice.internal.websocket.WebSocketConnection +import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage +import org.whispersystems.signalservice.internal.websocket.WebSocketResponseMessage +import org.whispersystems.signalservice.internal.websocket.WebsocketResponse +import java.io.IOException +import java.util.concurrent.TimeoutException + +/** + * Base wrapper around a [WebSocketConnection] to provide a more developer friend interface to websocket + * interactions. + */ +sealed class SignalWebSocket( + private val createConnection: () -> WebSocketConnection +) { + + companion object { + private val TAG = Log.tag(SignalWebSocket::class) + + const val SERVER_DELIVERED_TIMESTAMP_HEADER = "X-Signal-Timestamp" + } + + private var connection: WebSocketConnection? = null + private val _state: BehaviorSubject = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) + protected var disposable: CompositeDisposable = CompositeDisposable() + + private var canConnect = false + + var shouldSendKeepAlives: Boolean = true + set(value) { + field = value + keepAliveChangedListener?.invoke() + } + var keepAliveChangedListener: (() -> Unit)? = null + + val state: Observable = _state + + /** + * Indicate that WebSocketConnection can now be made and attempt to connect. + */ + @Synchronized + fun connect() { + canConnect = true + getWebSocket() + } + + /** + * Indicate that WebSocketConnection can no longer be made and disconnect. + */ + @Synchronized + fun disconnect() { + canConnect = false + disconnectInternal() + } + + private fun disconnectInternal() { + if (connection != null) { + disposable.dispose() + + connection!!.disconnect() + connection = null + + if (!_state.value!!.isFailure) { + _state.onNext(WebSocketConnectionState.DISCONNECTED) + } + } + } + + @Synchronized + @Throws(IOException::class) + fun sendKeepAlive() { + if (canConnect) { + getWebSocket().sendKeepAlive() + } + } + + fun request(request: WebSocketRequestMessage): Single { + return try { + getWebSocket().sendRequest(request) + } catch (e: IOException) { + Single.error(e) + } + } + + @Throws(IOException::class) + fun sendAck(response: EnvelopeResponse) { + getWebSocket().sendResponse(response.websocketRequest.getWebSocketResponse()) + } + + @Synchronized + @Throws(WebSocketUnavailableException::class) + protected fun getWebSocket(): WebSocketConnection { + if (!canConnect) { + throw WebSocketUnavailableException() + } + + if (connection == null || connection?.isDead() == true) { + disposable.dispose() + + disposable = CompositeDisposable() + val newConnection = createConnection() + + newConnection + .connect() + .subscribeOn(Schedulers.io()) + .observeOn(Schedulers.io()) + .subscribeBy { _state.onNext(it) } + .addTo(disposable) + + this.connection = newConnection + } + + return connection!! + } + + @Synchronized + fun forceNewWebSocket() { + Log.i(TAG, "Forcing new WebSockets connection: ${connection?.name ?: "[null]"} canConnect: $canConnect") + disconnectInternal() + } + + protected fun WebSocketRequestMessage.isSignalServiceEnvelope(): Boolean { + return "PUT" == this.verb && "/api/v1/message" == this.path + } + + protected fun WebSocketRequestMessage.isSocketEmptyRequest(): Boolean { + return "PUT" == this.verb && "/api/v1/queue/empty" == this.path + } + + private fun WebSocketRequestMessage.getWebSocketResponse(): WebSocketResponseMessage { + return if (this.isSignalServiceEnvelope()) { + WebSocketResponseMessage.Builder() + .id(this.id) + .status(200) + .message("OK") + .build() + } else { + WebSocketResponseMessage.Builder() + .id(this.id) + .status(400) + .message("Unknown") + .build() + } + } + + /** + * WebSocket type for communicating with the server without authenticating. Also known as "unidentified". + */ + class UnauthenticatedWebSocket(createConnection: () -> WebSocketConnection) : SignalWebSocket(createConnection) { + fun request(requestMessage: WebSocketRequestMessage, sealedSenderAccess: SealedSenderAccess): Single { + val headers: MutableList = requestMessage.headers.toMutableList() + headers.add(sealedSenderAccess.header) + + val message = requestMessage + .newBuilder() + .headers(headers) + .build() + + try { + return getWebSocket() + .sendRequest(message) + .flatMap { response -> + if (response.status == 401) { + val fallback = sealedSenderAccess.switchToFallback() + if (fallback != null) { + return@flatMap request(requestMessage, fallback) + } + } + Single.just(response) + } + } catch (e: IOException) { + return Single.error(e) + } + } + } + + /** + * WebSocket type for communicating with the server with authentication. Also known as "identified". + */ + class AuthenticatedWebSocket(createConnection: () -> WebSocketConnection) : SignalWebSocket(createConnection) { + + /** + * The reads a batch of messages off of the websocket. + * + * Rather than just provide you the batch as a return value, it will invoke the provided callback with the + * batch as an argument. If you are able to successfully process them, this method will then ack all of the + * messages so that they won't be re-delivered in the future. + * + * The return value of this method is a boolean indicating whether or not there are more messages in the + * queue to be read (true if there's still more, or false if you've drained everything). + * + * However, this return value is only really useful the first time you read from the websocket. That's because + * the websocket will only ever let you know if it's drained *once* for any given connection. So if this method + * returns false, a subsequent call while using the same websocket connection will simply block until we either + * get a new message or hit the timeout. + * + * Concerning the requested batch size, it's worth noting that this is simply an upper bound. This method will + * not wait extra time until the batch has "filled up". Instead, it will wait for a single message, and then + * take any extra messages that are also available up until you've hit your batch size. + */ + @Throws(TimeoutException::class, WebSocketUnavailableException::class, IOException::class) + fun readMessageBatch(timeout: Long, batchSize: Int, callback: MessageReceivedCallback): Boolean { + val responses: MutableList = ArrayList() + var hitEndOfQueue = false + + val firstEnvelope: EnvelopeResponse? = waitForSingleMessage(timeout) + + if (firstEnvelope != null) { + responses.add(firstEnvelope) + } else { + hitEndOfQueue = true + } + + if (!hitEndOfQueue) { + for (i in 1 until batchSize) { + val request = getWebSocket().readRequestIfAvailable().orNull() + + if (request != null) { + if (request.isSignalServiceEnvelope()) { + responses.add(request.toEnvelopeResponse()) + } else if (request.isSocketEmptyRequest()) { + hitEndOfQueue = true + break + } + } else { + break + } + } + } + + if (responses.size > 0) { + callback.onMessageBatch(responses) + } + + return !hitEndOfQueue + } + + @Throws(TimeoutException::class, WebSocketUnavailableException::class, IOException::class) + private fun waitForSingleMessage(timeout: Long): EnvelopeResponse? { + while (true) { + val request = getWebSocket().readRequest(timeout) + + if (request.isSignalServiceEnvelope()) { + return request.toEnvelopeResponse() + } else if (request.isSocketEmptyRequest()) { + return null + } + } + } + + @Throws(IOException::class) + private fun WebSocketRequestMessage.toEnvelopeResponse(): EnvelopeResponse { + val timestamp = this.findHeader() + + if (timestamp == null) { + Log.w(TAG, "Failed to parse $SERVER_DELIVERED_TIMESTAMP_HEADER") + } + + val envelope = Envelope.ADAPTER.decode(this.body!!.toByteArray()) + + return EnvelopeResponse(envelope, timestamp ?: 0, this) + } + + private fun WebSocketRequestMessage.findHeader(): Long? { + if (this.headers.isEmpty()) { + return null + } + + return this.headers + .asSequence() + .filter { it.startsWith(SERVER_DELIVERED_TIMESTAMP_HEADER) } + .map { it.split(":") } + .filter { it.size == 2 && it[0].trim().lowercase() == SERVER_DELIVERED_TIMESTAMP_HEADER.lowercase() } + .map { it[1].trim() } + .filter { it.isNotEmpty() } + .firstOrNull() + ?.toLongOrNull() + } + + /** + * For receiving a callback when a new message has been + * received. + */ + fun interface MessageReceivedCallback { + /** Called with the batch of envelopes. You are responsible for sending acks. */ + fun onMessageBatch(envelopeResponses: List) + } + } +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java index 5780202a1f..cbf771f642 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java @@ -11,7 +11,6 @@ import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState; import org.whispersystems.signalservice.internal.configuration.SignalProxy; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.configuration.SignalServiceUrl; -import org.whispersystems.signalservice.internal.push.AuthCredentials; import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager; import org.whispersystems.signalservice.internal.util.Util;