diff --git a/app/src/main/java/org/thoughtcrime/securesms/ContactSelectionListFragment.java b/app/src/main/java/org/thoughtcrime/securesms/ContactSelectionListFragment.java index 34d38114d3..9543f856c0 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/ContactSelectionListFragment.java +++ b/app/src/main/java/org/thoughtcrime/securesms/ContactSelectionListFragment.java @@ -47,8 +47,8 @@ import androidx.transition.TransitionManager; import com.google.android.material.dialog.MaterialAlertDialogBuilder; +import org.signal.core.util.concurrent.JvmRxExtensions; import org.signal.core.util.concurrent.LifecycleDisposable; -import org.signal.core.util.concurrent.RxExtensions; import org.signal.core.util.concurrent.SimpleTask; import org.signal.core.util.logging.Log; import org.thoughtcrime.securesms.calls.YouAreAlreadyInACallSnackbar; @@ -722,7 +722,7 @@ public final class ContactSelectionListFragment extends LoggingFragment { SimpleTask.run(getViewLifecycleOwner().getLifecycle(), () -> { try { - return RxExtensions.safeBlockingGet(UsernameRepository.fetchAciForUsername(UsernameUtil.sanitizeUsernameFromSearch(username))); + return JvmRxExtensions.safeBlockingGet(UsernameRepository.fetchAciForUsername(UsernameUtil.sanitizeUsernameFromSearch(username))); } catch (InterruptedException e) { Log.w(TAG, "Interrupted?", e); return UsernameAciFetchResult.NetworkError.INSTANCE; diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt b/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt index 7792a5c9b2..316cd51d4a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/AppDependencies.kt @@ -377,7 +377,7 @@ object AppDependencies { fun provideArchiveApi(pushServiceSocket: PushServiceSocket): ArchiveApi fun provideKeysApi(pushServiceSocket: PushServiceSocket): KeysApi fun provideAttachmentApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi - fun provideLinkDeviceApi(pushServiceSocket: PushServiceSocket): LinkDeviceApi + fun provideLinkDeviceApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket): LinkDeviceApi fun provideRegistrationApi(pushServiceSocket: PushServiceSocket): RegistrationApi fun provideStorageServiceApi(pushServiceSocket: PushServiceSocket): StorageServiceApi fun provideAuthWebSocket(signalServiceConfigurationSupplier: Supplier, libSignalNetworkSupplier: Supplier): SignalWebSocket.AuthenticatedWebSocket diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index 85d4282dc1..35cbe11d58 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -478,8 +478,8 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { } @Override - public @NonNull LinkDeviceApi provideLinkDeviceApi(@NonNull PushServiceSocket pushServiceSocket) { - return new LinkDeviceApi(pushServiceSocket); + public @NonNull LinkDeviceApi provideLinkDeviceApi(@NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket) { + return new LinkDeviceApi(authWebSocket); } @Override diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt b/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt index 572de3043c..d69535a151 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/NetworkDependenciesModule.kt @@ -146,7 +146,7 @@ class NetworkDependenciesModule( } val linkDeviceApi: LinkDeviceApi by lazy { - provider.provideLinkDeviceApi(pushServiceSocket) + provider.provideLinkDeviceApi(authWebSocket) } val registrationApi: RegistrationApi by lazy { diff --git a/app/src/main/java/org/thoughtcrime/securesms/jobs/LinkedDeviceInactiveCheckJob.kt b/app/src/main/java/org/thoughtcrime/securesms/jobs/LinkedDeviceInactiveCheckJob.kt index 0078147715..d821e99e95 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/jobs/LinkedDeviceInactiveCheckJob.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/jobs/LinkedDeviceInactiveCheckJob.kt @@ -68,7 +68,10 @@ class LinkedDeviceInactiveCheckJob private constructor( } val devices = try { - AppDependencies.signalServiceAccountManager.devices + AppDependencies + .linkDeviceApi + .getDevices() + .successOrThrow() .filter { it.id != SignalServiceAddress.DEFAULT_DEVICE_ID } } catch (e: IOException) { return Result.retry(defaultBackoff()) diff --git a/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt b/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt index 71ac6b5945..79aeffb401 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt @@ -45,29 +45,32 @@ object LinkDeviceRepository { private val TAG = Log.tag(LinkDeviceRepository::class) fun removeDevice(deviceId: Int): Boolean { - return try { - val accountManager = AppDependencies.signalServiceAccountManager - accountManager.removeDevice(deviceId) - LinkedDeviceInactiveCheckJob.enqueue() - true - } catch (e: IOException) { - Log.w(TAG, e) - false + return when (val result = AppDependencies.linkDeviceApi.removeDevice(deviceId)) { + is NetworkResult.Success -> { + LinkedDeviceInactiveCheckJob.enqueue() + true + } + else -> { + Log.w(TAG, "Unable to remove device", result.getCause()) + false + } } } fun loadDevices(): List? { - val accountManager = AppDependencies.signalServiceAccountManager - return try { - val devices: List = accountManager.getDevices() - .filter { d: DeviceInfo -> d.getId() != SignalServiceAddress.DEFAULT_DEVICE_ID } - .map { deviceInfo: DeviceInfo -> deviceInfo.toDevice() } - .sortedBy { it.createdMillis } - .toList() - devices - } catch (e: IOException) { - Log.w(TAG, e) - null + return when (val result = AppDependencies.linkDeviceApi.getDevices()) { + is NetworkResult.Success -> { + result + .result + .filter { d: DeviceInfo -> d.getId() != SignalServiceAddress.DEFAULT_DEVICE_ID } + .map { deviceInfo: DeviceInfo -> deviceInfo.toDevice() } + .sortedBy { it.createdMillis } + .toList() + } + else -> { + Log.w(TAG, "Unable to load device", result.getCause()) + null + } } } @@ -132,12 +135,12 @@ object LinkDeviceRepository { val verificationCodeResult: LinkedDeviceVerificationCodeResponse = when (val result = SignalNetwork.linkDevice.getDeviceVerificationCode()) { is NetworkResult.Success -> result.result is NetworkResult.ApplicationError -> throw result.throwable - is NetworkResult.NetworkError -> return LinkDeviceResult.NetworkError + is NetworkResult.NetworkError -> return LinkDeviceResult.NetworkError(result.exception) is NetworkResult.StatusCodeError -> { return when (result.code) { 411 -> LinkDeviceResult.LimitExceeded - 429 -> LinkDeviceResult.NetworkError - else -> LinkDeviceResult.NetworkError + 429 -> LinkDeviceResult.NetworkError(result.exception) + else -> LinkDeviceResult.NetworkError(result.exception) } } } @@ -171,15 +174,15 @@ object LinkDeviceRepository { LinkDeviceResult.Success(verificationCodeResult.tokenIdentifier) } is NetworkResult.ApplicationError -> throw deviceLinkResult.throwable - is NetworkResult.NetworkError -> LinkDeviceResult.NetworkError + is NetworkResult.NetworkError -> LinkDeviceResult.NetworkError(deviceLinkResult.exception) is NetworkResult.StatusCodeError -> { when (deviceLinkResult.code) { 403 -> LinkDeviceResult.NoDevice 409 -> LinkDeviceResult.NoDevice 411 -> LinkDeviceResult.LimitExceeded - 422 -> LinkDeviceResult.NetworkError - 429 -> LinkDeviceResult.NetworkError - else -> LinkDeviceResult.NetworkError + 422 -> LinkDeviceResult.NetworkError(deviceLinkResult.exception) + 429 -> LinkDeviceResult.NetworkError(deviceLinkResult.exception) + else -> LinkDeviceResult.NetworkError(deviceLinkResult.exception) } } } @@ -200,7 +203,7 @@ object LinkDeviceRepository { Log.d(TAG, "[waitForDeviceToBeLinked] Willing to wait for $timeRemaining ms...") val result = SignalNetwork.linkDevice.waitForLinkedDevice( token = token, - timeoutSeconds = timeRemaining.milliseconds.inWholeSeconds.toInt() + timeout = timeRemaining.milliseconds ) when (result) { @@ -422,7 +425,7 @@ object LinkDeviceRepository { data object None : LinkDeviceResult data class Success(val token: String) : LinkDeviceResult data object NoDevice : LinkDeviceResult - data object NetworkError : LinkDeviceResult + data class NetworkError(val error: Throwable) : LinkDeviceResult data object KeyError : LinkDeviceResult data object LimitExceeded : LinkDeviceResult data object BadCode : LinkDeviceResult diff --git a/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceViewModel.kt b/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceViewModel.kt index ecc9212fe4..6cf02e3b67 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceViewModel.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceViewModel.kt @@ -288,7 +288,7 @@ class LinkDeviceViewModel : ViewModel() { Log.d(TAG, "[addDeviceWithSync] Got result: $result") if (result !is LinkDeviceResult.Success) { - Log.w(TAG, "[addDeviceWithSync] Unable to link device $result") + Log.w(TAG, "[addDeviceWithSync] Unable to link device $result", if (result is LinkDeviceResult.NetworkError) result.error else null) _state.update { it.copy( dialogState = DialogState.None @@ -377,7 +377,7 @@ class LinkDeviceViewModel : ViewModel() { } if (result !is LinkDeviceResult.Success) { - Log.w(TAG, "Unable to link device $result") + Log.w(TAG, "Unable to link device $result", if (result is LinkDeviceResult.NetworkError) result.error else null) _state.update { it.copy( dialogState = DialogState.None diff --git a/app/src/main/java/org/thoughtcrime/securesms/util/CommunicationActions.java b/app/src/main/java/org/thoughtcrime/securesms/util/CommunicationActions.java index 059e4c3350..d3d502f480 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/util/CommunicationActions.java +++ b/app/src/main/java/org/thoughtcrime/securesms/util/CommunicationActions.java @@ -23,7 +23,7 @@ import androidx.fragment.app.FragmentManager; import com.google.android.material.dialog.MaterialAlertDialogBuilder; -import org.signal.core.util.concurrent.RxExtensions; +import org.signal.core.util.concurrent.JvmRxExtensions; import org.signal.core.util.concurrent.SignalExecutors; import org.signal.core.util.concurrent.SimpleTask; import org.signal.core.util.logging.Log; @@ -460,7 +460,7 @@ public class CommunicationActions { SimpleTask.run(() -> { try { - UsernameLinkConversionResult result = RxExtensions.safeBlockingGet(UsernameRepository.fetchUsernameAndAciFromLink(link)); + UsernameLinkConversionResult result = JvmRxExtensions.safeBlockingGet(UsernameRepository.fetchUsernameAndAciFromLink(link)); // TODO we could be better here and report different types of errors to the UI if (result instanceof UsernameLinkConversionResult.Success success) { diff --git a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt index 9771245c94..8a448e45f4 100644 --- a/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt +++ b/app/src/test/java/org/thoughtcrime/securesms/dependencies/MockApplicationDependencyProvider.kt @@ -226,7 +226,7 @@ class MockApplicationDependencyProvider : AppDependencies.Provider { return mockk(relaxed = true) } - override fun provideLinkDeviceApi(pushServiceSocket: PushServiceSocket): LinkDeviceApi { + override fun provideLinkDeviceApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket): LinkDeviceApi { return mockk(relaxed = true) } diff --git a/core-util-jvm/build.gradle.kts b/core-util-jvm/build.gradle.kts index 4e8b0af1b2..ee4cf31a75 100644 --- a/core-util-jvm/build.gradle.kts +++ b/core-util-jvm/build.gradle.kts @@ -54,6 +54,8 @@ dependencies { implementation(libs.kotlinx.coroutines.core) implementation(libs.kotlinx.coroutines.core.jvm) implementation(libs.google.libphonenumber) + implementation(libs.rxjava3.rxjava) + implementation(libs.rxjava3.rxkotlin) testImplementation(testLibs.junit.junit) testImplementation(testLibs.assertk) diff --git a/core-util-jvm/src/main/java/org/signal/core/util/concurrent/JvmRxExtensions.kt b/core-util-jvm/src/main/java/org/signal/core/util/concurrent/JvmRxExtensions.kt new file mode 100644 index 0000000000..a5c370a47a --- /dev/null +++ b/core-util-jvm/src/main/java/org/signal/core/util/concurrent/JvmRxExtensions.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +@file:JvmName("JvmRxExtensions") + +package org.signal.core.util.concurrent + +import io.reactivex.rxjava3.core.Single + +/** + * Throw an [InterruptedException] if a [Single.blockingGet] call is interrupted. This can + * happen when being called by code already within an Rx chain that is disposed. + * + * [Single.blockingGet] is considered harmful and should not be used. + */ +@Throws(InterruptedException::class) +fun Single.safeBlockingGet(): T { + try { + return blockingGet() + } catch (e: RuntimeException) { + val cause = e.cause + if (cause is InterruptedException) { + throw cause + } else { + throw e + } + } +} diff --git a/core-util/src/main/java/org/signal/core/util/concurrent/RxExtensions.kt b/core-util/src/main/java/org/signal/core/util/concurrent/RxExtensions.kt index 2d2ad45552..c1aabad534 100644 --- a/core-util/src/main/java/org/signal/core/util/concurrent/RxExtensions.kt +++ b/core-util/src/main/java/org/signal/core/util/concurrent/RxExtensions.kt @@ -1,8 +1,12 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + @file:JvmName("RxExtensions") package org.signal.core.util.concurrent -import android.annotation.SuppressLint import androidx.lifecycle.LifecycleOwner import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Flowable @@ -13,27 +17,6 @@ import io.reactivex.rxjava3.kotlin.addTo import io.reactivex.rxjava3.kotlin.subscribeBy import io.reactivex.rxjava3.subjects.Subject -/** - * Throw an [InterruptedException] if a [Single.blockingGet] call is interrupted. This can - * happen when being called by code already within an Rx chain that is disposed. - * - * [Single.blockingGet] is considered harmful and should not be used. - */ -@SuppressLint("UnsafeBlockingGet") -@Throws(InterruptedException::class) -fun Single.safeBlockingGet(): T { - try { - return blockingGet() - } catch (e: RuntimeException) { - val cause = e.cause - if (cause is InterruptedException) { - throw cause - } else { - throw e - } - } -} - fun Flowable.observe(viewLifecycleOwner: LifecycleOwner, onNext: (T) -> Unit) { val lifecycleDisposable = LifecycleDisposable() lifecycleDisposable.bindTo(viewLifecycleOwner) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt index bc6249be1b..99046ddb17 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/NetworkResult.kt @@ -5,15 +5,21 @@ package org.whispersystems.signalservice.api +import org.signal.core.util.concurrent.safeBlockingGet +import org.whispersystems.signalservice.api.NetworkResult.StatusCodeError +import org.whispersystems.signalservice.api.NetworkResult.Success import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.internal.util.JsonUtil +import org.whispersystems.signalservice.internal.websocket.WebSocketConnection import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage import org.whispersystems.signalservice.internal.websocket.WebsocketResponse import java.io.IOException import java.util.concurrent.TimeoutException import kotlin.reflect.KClass +import kotlin.reflect.cast +import kotlin.time.Duration typealias StatusCodeErrorAction = (NetworkResult.StatusCodeError<*>) -> Unit @@ -51,49 +57,64 @@ sealed class NetworkResult( ApplicationError(e) } + /** + * A convenience method to convert a websocket request into a network result. + * Common HTTP errors will be translated to [StatusCodeError]s. + */ + @JvmStatic + fun fromWebSocketRequest( + signalWebSocket: SignalWebSocket, + request: WebSocketRequestMessage + ): NetworkResult = fromWebSocketRequest( + signalWebSocket = signalWebSocket, + request = request, + clazz = Unit::class + ) + /** * A convenience method to convert a websocket request into a network result with simple conversion of the response body to the desired class. - * Common exceptions will be caught and translated to errors. + * Common HTTP errors will be translated to [StatusCodeError]s. */ @JvmStatic fun fromWebSocketRequest( signalWebSocket: SignalWebSocket, request: WebSocketRequestMessage, - clazz: KClass + clazz: KClass, + timeout: Duration = WebSocketConnection.DEFAULT_SEND_TIMEOUT + ): NetworkResult { + return fromWebSocketRequest( + signalWebSocket = signalWebSocket, + request = request, + timeout = timeout, + webSocketResponseConverter = DefaultWebSocketConverter(clazz) + ) + } + + /** + * A convenience method to convert a websocket request into a network result with the ability to fully customize the conversion of the response. + * Common HTTP errors will be translated to [StatusCodeError]s. + */ + @JvmStatic + fun fromWebSocketRequest( + signalWebSocket: SignalWebSocket, + request: WebSocketRequestMessage, + timeout: Duration = WebSocketConnection.DEFAULT_SEND_TIMEOUT, + webSocketResponseConverter: WebSocketResponseConverter ): NetworkResult = try { - val result: Result = signalWebSocket.request(request) - .map { response: WebsocketResponse -> Result.success(JsonUtil.fromJson(response.body, clazz.java)) } - .onErrorReturn { Result.failure(it) } - .blockingGet() - Success(result.getOrThrow()) + val result: Result> = signalWebSocket.request(request, timeout) + .map { response: WebsocketResponse -> Result.success(webSocketResponseConverter.convert(response)) } + .onErrorReturn { Result.failure(it) } + .safeBlockingGet() + + result.getOrThrow() } catch (e: NonSuccessfulResponseCodeException) { StatusCodeError(e) } catch (e: IOException) { NetworkError(e) } catch (e: TimeoutException) { NetworkError(PushNetworkException(e)) - } catch (e: Throwable) { - ApplicationError(e) - } - - /** - * A convenience method to convert a websocket request into a network result with the ability to convert the response to your target class. - * Common exceptions will be caught and translated to errors. - */ - @JvmStatic - fun fromWebSocketRequest( - signalWebSocket: SignalWebSocket, - request: WebSocketRequestMessage, - webSocketResponseConverter: WebSocketResponseConverter - ): NetworkResult = try { - val result = signalWebSocket.request(request) - .map { response: WebsocketResponse -> webSocketResponseConverter.convert(response) } - .blockingGet() - Success(result) - } catch (e: NonSuccessfulResponseCodeException) { - StatusCodeError(e) - } catch (e: IOException) { - NetworkError(e) + } catch (e: InterruptedException) { + NetworkError(PushNetworkException(e)) } catch (e: Throwable) { ApplicationError(e) } @@ -308,6 +329,37 @@ sealed class NetworkResult( fun interface WebSocketResponseConverter { @Throws(Exception::class) - fun convert(response: WebsocketResponse): T + fun convert(response: WebsocketResponse): NetworkResult + } + + class DefaultWebSocketConverter(private val responseJsonClass: KClass) : WebSocketResponseConverter { + override fun convert(response: WebsocketResponse): NetworkResult { + return if (response.status < 200 || response.status > 299) { + response.toStatusCodeError() + } else { + response.toSuccess(responseJsonClass) + } + } + } + + class LongPollingWebSocketConverter(private val responseJsonClass: KClass) : WebSocketResponseConverter { + override fun convert(response: WebsocketResponse): NetworkResult { + return if (response.status == 204 || response.status < 200 || response.status > 299) { + response.toStatusCodeError() + } else { + response.toSuccess(responseJsonClass) + } + } } } + +private fun WebsocketResponse.toStatusCodeError(): NetworkResult { + return StatusCodeError(NonSuccessfulResponseCodeException(this.status, "", this.body)) +} + +private fun WebsocketResponse.toSuccess(responseJsonClass: KClass): NetworkResult { + if (responseJsonClass == Unit::class) { + return Success(responseJsonClass.cast(Unit)) + } + return Success(JsonUtil.fromJson(this.body, responseJsonClass.java)) +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java index 3df7ca0468..48b9e60b0a 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceAccountManager.java @@ -297,15 +297,6 @@ public class SignalServiceAccountManager { return pushServiceSocket.getAccountDataReport(); } - - public List getDevices() throws IOException { - return this.pushServiceSocket.getDevices(); - } - - public void removeDevice(int deviceId) throws IOException { - this.pushServiceSocket.removeDevice(deviceId); - } - public List getTurnServerInfo() throws IOException { List relays = this.pushServiceSocket.getCallingRelays().getRelays(); return relays != null ? relays : Collections.emptyList(); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/link/LinkDeviceApi.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/link/LinkDeviceApi.kt index ecc54b95da..41f296fa20 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/link/LinkDeviceApi.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/link/LinkDeviceApi.kt @@ -6,6 +6,8 @@ package org.whispersystems.signalservice.api.link import okio.ByteString.Companion.toByteString +import org.signal.core.util.Base64.encodeWithPadding +import org.signal.core.util.urlEncode import org.signal.libsignal.protocol.IdentityKeyPair import org.signal.libsignal.protocol.ecc.ECPublicKey import org.signal.libsignal.zkgroup.profiles.ProfileKey @@ -13,18 +15,54 @@ import org.whispersystems.signalservice.api.NetworkResult import org.whispersystems.signalservice.api.backup.MediaRootBackupKey import org.whispersystems.signalservice.api.backup.MessageBackupKey import org.whispersystems.signalservice.api.kbs.MasterKey +import org.whispersystems.signalservice.api.messages.multidevice.DeviceInfo import org.whispersystems.signalservice.api.push.ServiceId.ACI import org.whispersystems.signalservice.api.push.ServiceId.PNI +import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.internal.crypto.PrimaryProvisioningCipher +import org.whispersystems.signalservice.internal.delete +import org.whispersystems.signalservice.internal.get +import org.whispersystems.signalservice.internal.push.DeviceInfoList import org.whispersystems.signalservice.internal.push.ProvisionMessage +import org.whispersystems.signalservice.internal.push.ProvisioningMessage import org.whispersystems.signalservice.internal.push.ProvisioningVersion -import org.whispersystems.signalservice.internal.push.PushServiceSocket -import kotlin.math.min +import org.whispersystems.signalservice.internal.put +import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds /** * Class to interact with device-linking endpoints. */ -class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) { +class LinkDeviceApi( + private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket +) { + /** + * Fetches a list of linked devices. + * + * GET /v1/devices + * + * - 200: Success + */ + fun getDevices(): NetworkResult> { + val request = WebSocketRequestMessage.get("/v1/devices") + return NetworkResult + .fromWebSocketRequest(authWebSocket, request, DeviceInfoList::class) + .map { it.getDevices() } + } + + /** + * Remove and unlink a linked device. + * + * DELETE /v1/devices/{id} + * + * - 200: Success + */ + fun removeDevice(deviceId: Int): NetworkResult { + val request = WebSocketRequestMessage.delete("/v1/devices/$deviceId") + return NetworkResult + .fromWebSocketRequest(authWebSocket, request) + } /** * Fetches a new verification code that lets you link a new device. @@ -36,15 +74,15 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) { * - 429: Rate-limited. */ fun getDeviceVerificationCode(): NetworkResult { - return NetworkResult.fromFetch { - pushServiceSocket.getLinkedDeviceVerificationCode() - } + val request = WebSocketRequestMessage.get("/v1/devices/provisioning/code") + return NetworkResult + .fromWebSocketRequest(authWebSocket, request, LinkedDeviceVerificationCodeResponse::class) } /** * Links a new device to the account. * - * PUT /v1/devices/link + * PUT /v1/provisioning/[deviceIdentifier] * * - 200: Success. * - 403: Account not found or incorrect verification code. @@ -67,45 +105,50 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) { code: String, ephemeralMessageBackupKey: MessageBackupKey? ): NetworkResult { - return NetworkResult.fromFetch { - val cipher = PrimaryProvisioningCipher(deviceKey) - val message = ProvisionMessage( - aciIdentityKeyPublic = aciIdentityKeyPair.publicKey.serialize().toByteString(), - aciIdentityKeyPrivate = aciIdentityKeyPair.privateKey.serialize().toByteString(), - pniIdentityKeyPublic = pniIdentityKeyPair.publicKey.serialize().toByteString(), - pniIdentityKeyPrivate = pniIdentityKeyPair.privateKey.serialize().toByteString(), - aci = aci.toString(), - pni = pni.toStringWithoutPrefix(), - number = e164, - profileKey = profileKey.serialize().toByteString(), - provisioningCode = code, - provisioningVersion = ProvisioningVersion.CURRENT.value, - masterKey = masterKey.serialize().toByteString(), - mediaRootBackupKey = mediaRootBackupKey.value.toByteString(), - ephemeralBackupKey = ephemeralMessageBackupKey?.value?.toByteString() - ) - val ciphertext = cipher.encrypt(message) + val cipher = PrimaryProvisioningCipher(deviceKey) + val message = ProvisionMessage( + aciIdentityKeyPublic = aciIdentityKeyPair.publicKey.serialize().toByteString(), + aciIdentityKeyPrivate = aciIdentityKeyPair.privateKey.serialize().toByteString(), + pniIdentityKeyPublic = pniIdentityKeyPair.publicKey.serialize().toByteString(), + pniIdentityKeyPrivate = pniIdentityKeyPair.privateKey.serialize().toByteString(), + aci = aci.toString(), + pni = pni.toStringWithoutPrefix(), + number = e164, + profileKey = profileKey.serialize().toByteString(), + provisioningCode = code, + provisioningVersion = ProvisioningVersion.CURRENT.value, + masterKey = masterKey.serialize().toByteString(), + mediaRootBackupKey = mediaRootBackupKey.value.toByteString(), + ephemeralBackupKey = ephemeralMessageBackupKey?.value?.toByteString() + ) + val ciphertext: ByteArray = cipher.encrypt(message) + val body = ProvisioningMessage(encodeWithPadding(ciphertext)) - pushServiceSocket.sendProvisioningMessage(deviceIdentifier, ciphertext) - } + val request = WebSocketRequestMessage.put("/v1/provisioning/${deviceIdentifier.urlEncode()}", body) + return NetworkResult.fromWebSocketRequest(authWebSocket, request) } /** * A "long-polling" endpoint that will return once the device has successfully been linked. * - * @param timeoutSeconds The max amount of time to wait. Capped at 30 seconds. + * @param timeout The max amount of time to wait. Capped at 30 seconds. * - * GET /v1/devices/wait_for_linked_device/{token} + * GET /v1/devices/wait_for_linked_device/[token]?timeout=[timeout] * * - 200: Success, a new device was linked associated with the provided token. * - 204: No device was linked before the max waiting time elapsed. * - 400: Invalid token/timeout. * - 429: Rate-limited. */ - fun waitForLinkedDevice(token: String, timeoutSeconds: Int = 30): NetworkResult { - return NetworkResult.fromFetch { - pushServiceSocket.waitForLinkedDevice(token, min(timeoutSeconds, 30)) - } + fun waitForLinkedDevice(token: String, timeout: Duration = 30.seconds): NetworkResult { + val request = WebSocketRequestMessage.get("/v1/devices/wait_for_linked_device/${token.urlEncode()}?timeout=${timeout.inWholeSeconds}") + return NetworkResult + .fromWebSocketRequest( + signalWebSocket = authWebSocket, + request = request, + timeout = timeout, + webSocketResponseConverter = NetworkResult.LongPollingWebSocketConverter(WaitForLinkedDeviceResponse::class) + ) } /** @@ -118,18 +161,16 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) { * - 429: Rate-limited. */ fun setTransferArchive(destinationDeviceId: Int, destinationDeviceCreated: Long, cdn: Int, cdnKey: String): NetworkResult { - return NetworkResult.fromFetch { - pushServiceSocket.setLinkedDeviceTransferArchive( - SetLinkedDeviceTransferArchiveRequest( - destinationDeviceId = destinationDeviceId, - destinationDeviceCreated = destinationDeviceCreated, - transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.CdnInfo( - cdn = cdn, - key = cdnKey - ) - ) + val body = SetLinkedDeviceTransferArchiveRequest( + destinationDeviceId = destinationDeviceId, + destinationDeviceCreated = destinationDeviceCreated, + transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.CdnInfo( + cdn = cdn, + key = cdnKey ) - } + ) + val request = WebSocketRequestMessage.put("/v1/devices/transfer_archive", body) + return NetworkResult.fromWebSocketRequest(authWebSocket, request) } /** @@ -143,31 +184,26 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) { * - 429: Rate-limited. */ fun setTransferArchiveError(destinationDeviceId: Int, destinationDeviceCreated: Long, error: TransferArchiveError): NetworkResult { - return NetworkResult.fromFetch { - pushServiceSocket.setLinkedDeviceTransferArchive( - SetLinkedDeviceTransferArchiveRequest( - destinationDeviceId = destinationDeviceId, - destinationDeviceCreated = destinationDeviceCreated, - transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.Error( - error - ) - ) - ) - } + val body = SetLinkedDeviceTransferArchiveRequest( + destinationDeviceId = destinationDeviceId, + destinationDeviceCreated = destinationDeviceCreated, + transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.Error(error) + ) + val request = WebSocketRequestMessage.put("/v1/devices/transfer_archive", body) + return NetworkResult.fromWebSocketRequest(authWebSocket, request) } /** * Sets the name for a linked device * - * PUT /v1/accounts/name + * PUT /v1/accounts/name?deviceId=[deviceId] * * - 204: Success. * - 403: Not authorized to change the name of the device with the given ID * - 404: No device found with the given ID */ fun setDeviceName(encryptedDeviceName: String, deviceId: Int): NetworkResult { - return NetworkResult.fromFetch { - pushServiceSocket.setDeviceName(deviceId, SetDeviceNameRequest(encryptedDeviceName)) - } + val request = WebSocketRequestMessage.put("/v1/accounts/name?deviceId=$deviceId", SetDeviceNameRequest(encryptedDeviceName)) + return NetworkResult.fromWebSocketRequest(authWebSocket, request) } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt index f78837cc1d..691b11ff29 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt @@ -23,6 +23,7 @@ import org.whispersystems.signalservice.internal.websocket.WebSocketResponseMess import org.whispersystems.signalservice.internal.websocket.WebsocketResponse import java.io.IOException import java.util.concurrent.TimeoutException +import kotlin.time.Duration /** * Base wrapper around a [WebSocketConnection] to provide a more developer friend interface to websocket @@ -100,6 +101,14 @@ sealed class SignalWebSocket( } } + fun request(request: WebSocketRequestMessage, timeout: Duration): Single { + return try { + getWebSocket().sendRequest(request, timeout.inWholeSeconds) + } catch (e: IOException) { + Single.error(e) + } + } + @Throws(IOException::class) fun sendAck(response: EnvelopeResponse) { getWebSocket().sendResponse(response.websocketRequest.getWebSocketResponse()) diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/WebSocketRequestExt.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/WebSocketRequestExt.kt new file mode 100644 index 0000000000..6e633c80cc --- /dev/null +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/WebSocketRequestExt.kt @@ -0,0 +1,45 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.signalservice.internal + +import org.whispersystems.signalservice.internal.util.JsonUtil +import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage +import java.security.SecureRandom + +/** + * Create a basic GET web socket request + */ +fun WebSocketRequestMessage.Companion.get(path: String): WebSocketRequestMessage { + return WebSocketRequestMessage( + verb = "GET", + path = path, + id = SecureRandom().nextLong() + ) +} + +/** + * Create a basic DELETE web socket request + */ +fun WebSocketRequestMessage.Companion.delete(path: String): WebSocketRequestMessage { + return WebSocketRequestMessage( + verb = "DELETE", + path = path, + id = SecureRandom().nextLong() + ) +} + +/** + * Create a basic PUT web socket request, where body is JSON-ified. + */ +fun WebSocketRequestMessage.Companion.put(path: String, body: Any): WebSocketRequestMessage { + return WebSocketRequestMessage( + verb = "PUT", + path = path, + headers = listOf("content-type:application/json"), + body = JsonUtil.toJsonByteString(body), + id = SecureRandom().nextLong() + ) +} diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java index 0602ba09a3..fb07ed10b9 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/push/PushServiceSocket.java @@ -253,11 +253,7 @@ public class PushServiceSocket { private static final String CALLING_RELAYS = "/v2/calling/relays"; - private static final String PROVISIONING_CODE_PATH = "/v1/devices/provisioning/code"; private static final String PROVISIONING_MESSAGE_PATH = "/v1/provisioning/%s"; - private static final String DEVICE_PATH = "/v1/devices/%s"; - private static final String WAIT_FOR_DEVICES_PATH = "/v1/devices/wait_for_linked_device/%s?timeout=%s"; - private static final String TRANSFER_ARCHIVE_PATH = "/v1/devices/transfer_archive"; private static final String SET_RESTORE_METHOD_PATH = "/v1/devices/restore_account/%s"; private static final String WAIT_RESTORE_METHOD_PATH = "/v1/devices/restore_account/%s?timeout=%s"; @@ -698,29 +694,6 @@ public class PushServiceSocket { makeServiceRequest(SET_ACCOUNT_ATTRIBUTES, "PUT", JsonUtil.toJson(accountAttributes)); } - public LinkedDeviceVerificationCodeResponse getLinkedDeviceVerificationCode() throws IOException { - String responseText = makeServiceRequest(PROVISIONING_CODE_PATH, "GET", null, NO_HEADERS, UNOPINIONATED_HANDLER, SealedSenderAccess.NONE); - return JsonUtil.fromJson(responseText, LinkedDeviceVerificationCodeResponse.class); - } - - public List getDevices() throws IOException { - String responseText = makeServiceRequest(String.format(DEVICE_PATH, ""), "GET", null); - return JsonUtil.fromJson(responseText, DeviceInfoList.class).getDevices(); - } - - /** - * This is a long-polling endpoint that relies on the fact that our normal connection timeout is already 30s. - */ - public WaitForLinkedDeviceResponse waitForLinkedDevice(String token, int timeoutSeconds) throws IOException { - String response = makeServiceRequest(String.format(Locale.US, WAIT_FOR_DEVICES_PATH, token, timeoutSeconds), "GET", null, NO_HEADERS, LONG_POLL_HANDLER, SealedSenderAccess.NONE); - return JsonUtil.fromJsonResponse(response, WaitForLinkedDeviceResponse.class); - } - - public void setLinkedDeviceTransferArchive(SetLinkedDeviceTransferArchiveRequest request) throws IOException { - String body = JsonUtil.toJson(request); - makeServiceRequest(String.format(Locale.US, TRANSFER_ARCHIVE_PATH), "PUT", body, NO_HEADERS, UNOPINIONATED_HANDLER, SealedSenderAccess.NONE); - } - public void setRestoreMethodChosen(@Nonnull String token, @Nonnull RestoreMethodBody request) throws IOException { String body = JsonUtil.toJson(request); makeServiceRequest(String.format(Locale.US, SET_RESTORE_METHOD_PATH, urlEncode(token)), "PUT", body, NO_HEADERS, UNOPINIONATED_HANDLER, SealedSenderAccess.NONE); @@ -734,10 +707,6 @@ public class PushServiceSocket { return JsonUtil.fromJsonResponse(response, RestoreMethodBody.class); } - public void removeDevice(long deviceId) throws IOException { - makeServiceRequest(String.format(DEVICE_PATH, String.valueOf(deviceId)), "DELETE", null); - } - public void sendProvisioningMessage(String destination, byte[] body) throws IOException { makeServiceRequest(String.format(PROVISIONING_MESSAGE_PATH, urlEncode(destination)), "PUT", JsonUtil.toJson(new ProvisioningMessage(Base64.encodeWithPadding(body)))); diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index 4f860f50f8..9ef5e02519 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -37,6 +37,7 @@ import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.withLock +import kotlin.time.Duration import kotlin.time.Duration.Companion.seconds import org.signal.libsignal.net.ChatConnection.Request as LibSignalRequest @@ -95,17 +96,16 @@ class LibSignalChatConnection( const val SIGNAL_SERVICE_ENVELOPE_TIMESTAMP_HEADER_KEY = "X-Signal-Timestamp" private val TAG = Log.tag(LibSignalChatConnection::class.java) - private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds private val KEEP_ALIVE_REQUEST = LibSignalRequest( "GET", "/v1/keepalive", emptyMap(), ByteArray(0), - SEND_TIMEOUT.toInt() + WebSocketConnection.DEFAULT_SEND_TIMEOUT.inWholeMilliseconds.toInt() ) - private fun WebSocketRequestMessage.toLibSignalRequest(timeout: Long = SEND_TIMEOUT): LibSignalRequest { + private fun WebSocketRequestMessage.toLibSignalRequest(timeout: Duration = WebSocketConnection.DEFAULT_SEND_TIMEOUT): LibSignalRequest { return LibSignalRequest( this.verb?.uppercase() ?: "GET", this.path ?: "", @@ -117,7 +117,7 @@ class LibSignalChatConnection( parts[0] to parts[1] }, this.body?.toByteArray() ?: byteArrayOf(), - timeout.toInt() + timeout.inWholeMilliseconds.toInt() ) } } @@ -254,7 +254,7 @@ class LibSignalChatConnection( } } - override fun sendRequest(request: WebSocketRequestMessage): Single { + override fun sendRequest(request: WebSocketRequestMessage, timeoutSeconds: Long): Single { CHAT_SERVICE_LOCK.withLock { if (isDead()) { return Single.error(IOException("$name is closed!")) @@ -294,7 +294,7 @@ class LibSignalChatConnection( return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io()) } - val internalRequest = request.toLibSignalRequest() + val internalRequest = request.toLibSignalRequest(timeout = timeoutSeconds.seconds) chatConnection!!.send(internalRequest) .whenComplete( onSuccess = { response -> diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java index cbf771f642..9aee47a8fb 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/OkHttpWebSocketConnection.java @@ -1,5 +1,6 @@ package org.whispersystems.signalservice.internal.websocket; +import org.jetbrains.annotations.NotNull; import org.signal.libsignal.protocol.logging.Log; import org.signal.libsignal.protocol.util.Pair; import org.whispersystems.signalservice.api.push.TrustStore; @@ -227,7 +228,7 @@ public class OkHttpWebSocketConnection extends WebSocketListener implements WebS } @Override - public synchronized Single sendRequest(WebSocketRequestMessage request) throws IOException { + public synchronized Single sendRequest(@NotNull WebSocketRequestMessage request, long timeoutSeconds) throws IOException { if (client == null) { throw new IOException("No connection!"); } @@ -247,7 +248,7 @@ public class OkHttpWebSocketConnection extends WebSocketListener implements WebS return single.subscribeOn(Schedulers.io()) .observeOn(Schedulers.io()) - .timeout(10, TimeUnit.SECONDS, Schedulers.io()); + .timeout(timeoutSeconds, TimeUnit.SECONDS, Schedulers.io()); } @Override diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt index 259f47862d..0564da44a4 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt @@ -6,6 +6,7 @@ import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import java.io.IOException import java.util.Optional import java.util.concurrent.TimeoutException +import kotlin.time.Duration.Companion.seconds /** * Common interface for the web socket connection API @@ -15,6 +16,10 @@ import java.util.concurrent.TimeoutException * - LibSignalChatConnection - the wrapper around libsignal's [org.signal.libsignal.net.ChatService] */ interface WebSocketConnection { + companion object { + val DEFAULT_SEND_TIMEOUT = 10.seconds + } + val name: String fun connect(): Observable @@ -24,7 +29,12 @@ interface WebSocketConnection { fun disconnect() @Throws(IOException::class) - fun sendRequest(request: WebSocketRequestMessage): Single + fun sendRequest(request: WebSocketRequestMessage): Single { + return sendRequest(request, DEFAULT_SEND_TIMEOUT.inWholeSeconds) + } + + @Throws(IOException::class) + fun sendRequest(request: WebSocketRequestMessage, timeoutSeconds: Long): Single @Throws(IOException::class) fun sendKeepAlive()