Convert device linking apis to use websockets.

This commit is contained in:
Cody Henthorne
2025-03-06 16:17:55 -05:00
committed by Michelle Tang
parent 451d12ed53
commit c38342e2fb
21 changed files with 334 additions and 200 deletions

View File

@@ -47,8 +47,8 @@ import androidx.transition.TransitionManager;
import com.google.android.material.dialog.MaterialAlertDialogBuilder;
import org.signal.core.util.concurrent.JvmRxExtensions;
import org.signal.core.util.concurrent.LifecycleDisposable;
import org.signal.core.util.concurrent.RxExtensions;
import org.signal.core.util.concurrent.SimpleTask;
import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.calls.YouAreAlreadyInACallSnackbar;
@@ -722,7 +722,7 @@ public final class ContactSelectionListFragment extends LoggingFragment {
SimpleTask.run(getViewLifecycleOwner().getLifecycle(), () -> {
try {
return RxExtensions.safeBlockingGet(UsernameRepository.fetchAciForUsername(UsernameUtil.sanitizeUsernameFromSearch(username)));
return JvmRxExtensions.safeBlockingGet(UsernameRepository.fetchAciForUsername(UsernameUtil.sanitizeUsernameFromSearch(username)));
} catch (InterruptedException e) {
Log.w(TAG, "Interrupted?", e);
return UsernameAciFetchResult.NetworkError.INSTANCE;

View File

@@ -377,7 +377,7 @@ object AppDependencies {
fun provideArchiveApi(pushServiceSocket: PushServiceSocket): ArchiveApi
fun provideKeysApi(pushServiceSocket: PushServiceSocket): KeysApi
fun provideAttachmentApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket, pushServiceSocket: PushServiceSocket): AttachmentApi
fun provideLinkDeviceApi(pushServiceSocket: PushServiceSocket): LinkDeviceApi
fun provideLinkDeviceApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket): LinkDeviceApi
fun provideRegistrationApi(pushServiceSocket: PushServiceSocket): RegistrationApi
fun provideStorageServiceApi(pushServiceSocket: PushServiceSocket): StorageServiceApi
fun provideAuthWebSocket(signalServiceConfigurationSupplier: Supplier<SignalServiceConfiguration>, libSignalNetworkSupplier: Supplier<Network>): SignalWebSocket.AuthenticatedWebSocket

View File

@@ -478,8 +478,8 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
}
@Override
public @NonNull LinkDeviceApi provideLinkDeviceApi(@NonNull PushServiceSocket pushServiceSocket) {
return new LinkDeviceApi(pushServiceSocket);
public @NonNull LinkDeviceApi provideLinkDeviceApi(@NonNull SignalWebSocket.AuthenticatedWebSocket authWebSocket) {
return new LinkDeviceApi(authWebSocket);
}
@Override

View File

@@ -146,7 +146,7 @@ class NetworkDependenciesModule(
}
val linkDeviceApi: LinkDeviceApi by lazy {
provider.provideLinkDeviceApi(pushServiceSocket)
provider.provideLinkDeviceApi(authWebSocket)
}
val registrationApi: RegistrationApi by lazy {

View File

@@ -68,7 +68,10 @@ class LinkedDeviceInactiveCheckJob private constructor(
}
val devices = try {
AppDependencies.signalServiceAccountManager.devices
AppDependencies
.linkDeviceApi
.getDevices()
.successOrThrow()
.filter { it.id != SignalServiceAddress.DEFAULT_DEVICE_ID }
} catch (e: IOException) {
return Result.retry(defaultBackoff())

View File

@@ -45,29 +45,32 @@ object LinkDeviceRepository {
private val TAG = Log.tag(LinkDeviceRepository::class)
fun removeDevice(deviceId: Int): Boolean {
return try {
val accountManager = AppDependencies.signalServiceAccountManager
accountManager.removeDevice(deviceId)
LinkedDeviceInactiveCheckJob.enqueue()
true
} catch (e: IOException) {
Log.w(TAG, e)
false
return when (val result = AppDependencies.linkDeviceApi.removeDevice(deviceId)) {
is NetworkResult.Success -> {
LinkedDeviceInactiveCheckJob.enqueue()
true
}
else -> {
Log.w(TAG, "Unable to remove device", result.getCause())
false
}
}
}
fun loadDevices(): List<Device>? {
val accountManager = AppDependencies.signalServiceAccountManager
return try {
val devices: List<Device> = accountManager.getDevices()
.filter { d: DeviceInfo -> d.getId() != SignalServiceAddress.DEFAULT_DEVICE_ID }
.map { deviceInfo: DeviceInfo -> deviceInfo.toDevice() }
.sortedBy { it.createdMillis }
.toList()
devices
} catch (e: IOException) {
Log.w(TAG, e)
null
return when (val result = AppDependencies.linkDeviceApi.getDevices()) {
is NetworkResult.Success -> {
result
.result
.filter { d: DeviceInfo -> d.getId() != SignalServiceAddress.DEFAULT_DEVICE_ID }
.map { deviceInfo: DeviceInfo -> deviceInfo.toDevice() }
.sortedBy { it.createdMillis }
.toList()
}
else -> {
Log.w(TAG, "Unable to load device", result.getCause())
null
}
}
}
@@ -132,12 +135,12 @@ object LinkDeviceRepository {
val verificationCodeResult: LinkedDeviceVerificationCodeResponse = when (val result = SignalNetwork.linkDevice.getDeviceVerificationCode()) {
is NetworkResult.Success -> result.result
is NetworkResult.ApplicationError -> throw result.throwable
is NetworkResult.NetworkError -> return LinkDeviceResult.NetworkError
is NetworkResult.NetworkError -> return LinkDeviceResult.NetworkError(result.exception)
is NetworkResult.StatusCodeError -> {
return when (result.code) {
411 -> LinkDeviceResult.LimitExceeded
429 -> LinkDeviceResult.NetworkError
else -> LinkDeviceResult.NetworkError
429 -> LinkDeviceResult.NetworkError(result.exception)
else -> LinkDeviceResult.NetworkError(result.exception)
}
}
}
@@ -171,15 +174,15 @@ object LinkDeviceRepository {
LinkDeviceResult.Success(verificationCodeResult.tokenIdentifier)
}
is NetworkResult.ApplicationError -> throw deviceLinkResult.throwable
is NetworkResult.NetworkError -> LinkDeviceResult.NetworkError
is NetworkResult.NetworkError -> LinkDeviceResult.NetworkError(deviceLinkResult.exception)
is NetworkResult.StatusCodeError -> {
when (deviceLinkResult.code) {
403 -> LinkDeviceResult.NoDevice
409 -> LinkDeviceResult.NoDevice
411 -> LinkDeviceResult.LimitExceeded
422 -> LinkDeviceResult.NetworkError
429 -> LinkDeviceResult.NetworkError
else -> LinkDeviceResult.NetworkError
422 -> LinkDeviceResult.NetworkError(deviceLinkResult.exception)
429 -> LinkDeviceResult.NetworkError(deviceLinkResult.exception)
else -> LinkDeviceResult.NetworkError(deviceLinkResult.exception)
}
}
}
@@ -200,7 +203,7 @@ object LinkDeviceRepository {
Log.d(TAG, "[waitForDeviceToBeLinked] Willing to wait for $timeRemaining ms...")
val result = SignalNetwork.linkDevice.waitForLinkedDevice(
token = token,
timeoutSeconds = timeRemaining.milliseconds.inWholeSeconds.toInt()
timeout = timeRemaining.milliseconds
)
when (result) {
@@ -422,7 +425,7 @@ object LinkDeviceRepository {
data object None : LinkDeviceResult
data class Success(val token: String) : LinkDeviceResult
data object NoDevice : LinkDeviceResult
data object NetworkError : LinkDeviceResult
data class NetworkError(val error: Throwable) : LinkDeviceResult
data object KeyError : LinkDeviceResult
data object LimitExceeded : LinkDeviceResult
data object BadCode : LinkDeviceResult

View File

@@ -288,7 +288,7 @@ class LinkDeviceViewModel : ViewModel() {
Log.d(TAG, "[addDeviceWithSync] Got result: $result")
if (result !is LinkDeviceResult.Success) {
Log.w(TAG, "[addDeviceWithSync] Unable to link device $result")
Log.w(TAG, "[addDeviceWithSync] Unable to link device $result", if (result is LinkDeviceResult.NetworkError) result.error else null)
_state.update {
it.copy(
dialogState = DialogState.None
@@ -377,7 +377,7 @@ class LinkDeviceViewModel : ViewModel() {
}
if (result !is LinkDeviceResult.Success) {
Log.w(TAG, "Unable to link device $result")
Log.w(TAG, "Unable to link device $result", if (result is LinkDeviceResult.NetworkError) result.error else null)
_state.update {
it.copy(
dialogState = DialogState.None

View File

@@ -23,7 +23,7 @@ import androidx.fragment.app.FragmentManager;
import com.google.android.material.dialog.MaterialAlertDialogBuilder;
import org.signal.core.util.concurrent.RxExtensions;
import org.signal.core.util.concurrent.JvmRxExtensions;
import org.signal.core.util.concurrent.SignalExecutors;
import org.signal.core.util.concurrent.SimpleTask;
import org.signal.core.util.logging.Log;
@@ -460,7 +460,7 @@ public class CommunicationActions {
SimpleTask.run(() -> {
try {
UsernameLinkConversionResult result = RxExtensions.safeBlockingGet(UsernameRepository.fetchUsernameAndAciFromLink(link));
UsernameLinkConversionResult result = JvmRxExtensions.safeBlockingGet(UsernameRepository.fetchUsernameAndAciFromLink(link));
// TODO we could be better here and report different types of errors to the UI
if (result instanceof UsernameLinkConversionResult.Success success) {

View File

@@ -226,7 +226,7 @@ class MockApplicationDependencyProvider : AppDependencies.Provider {
return mockk(relaxed = true)
}
override fun provideLinkDeviceApi(pushServiceSocket: PushServiceSocket): LinkDeviceApi {
override fun provideLinkDeviceApi(authWebSocket: SignalWebSocket.AuthenticatedWebSocket): LinkDeviceApi {
return mockk(relaxed = true)
}

View File

@@ -54,6 +54,8 @@ dependencies {
implementation(libs.kotlinx.coroutines.core)
implementation(libs.kotlinx.coroutines.core.jvm)
implementation(libs.google.libphonenumber)
implementation(libs.rxjava3.rxjava)
implementation(libs.rxjava3.rxkotlin)
testImplementation(testLibs.junit.junit)
testImplementation(testLibs.assertk)

View File

@@ -0,0 +1,30 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@file:JvmName("JvmRxExtensions")
package org.signal.core.util.concurrent
import io.reactivex.rxjava3.core.Single
/**
* Throw an [InterruptedException] if a [Single.blockingGet] call is interrupted. This can
* happen when being called by code already within an Rx chain that is disposed.
*
* [Single.blockingGet] is considered harmful and should not be used.
*/
@Throws(InterruptedException::class)
fun <T : Any> Single<T>.safeBlockingGet(): T {
try {
return blockingGet()
} catch (e: RuntimeException) {
val cause = e.cause
if (cause is InterruptedException) {
throw cause
} else {
throw e
}
}
}

View File

@@ -1,8 +1,12 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@file:JvmName("RxExtensions")
package org.signal.core.util.concurrent
import android.annotation.SuppressLint
import androidx.lifecycle.LifecycleOwner
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
@@ -13,27 +17,6 @@ import io.reactivex.rxjava3.kotlin.addTo
import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.subjects.Subject
/**
* Throw an [InterruptedException] if a [Single.blockingGet] call is interrupted. This can
* happen when being called by code already within an Rx chain that is disposed.
*
* [Single.blockingGet] is considered harmful and should not be used.
*/
@SuppressLint("UnsafeBlockingGet")
@Throws(InterruptedException::class)
fun <T : Any> Single<T>.safeBlockingGet(): T {
try {
return blockingGet()
} catch (e: RuntimeException) {
val cause = e.cause
if (cause is InterruptedException) {
throw cause
} else {
throw e
}
}
}
fun <T : Any> Flowable<T>.observe(viewLifecycleOwner: LifecycleOwner, onNext: (T) -> Unit) {
val lifecycleDisposable = LifecycleDisposable()
lifecycleDisposable.bindTo(viewLifecycleOwner)

View File

@@ -5,15 +5,21 @@
package org.whispersystems.signalservice.api
import org.signal.core.util.concurrent.safeBlockingGet
import org.whispersystems.signalservice.api.NetworkResult.StatusCodeError
import org.whispersystems.signalservice.api.NetworkResult.Success
import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException
import org.whispersystems.signalservice.api.push.exceptions.PushNetworkException
import org.whispersystems.signalservice.api.websocket.SignalWebSocket
import org.whispersystems.signalservice.internal.util.JsonUtil
import org.whispersystems.signalservice.internal.websocket.WebSocketConnection
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
import org.whispersystems.signalservice.internal.websocket.WebsocketResponse
import java.io.IOException
import java.util.concurrent.TimeoutException
import kotlin.reflect.KClass
import kotlin.reflect.cast
import kotlin.time.Duration
typealias StatusCodeErrorAction = (NetworkResult.StatusCodeError<*>) -> Unit
@@ -51,49 +57,64 @@ sealed class NetworkResult<T>(
ApplicationError(e)
}
/**
* A convenience method to convert a websocket request into a network result.
* Common HTTP errors will be translated to [StatusCodeError]s.
*/
@JvmStatic
fun fromWebSocketRequest(
signalWebSocket: SignalWebSocket,
request: WebSocketRequestMessage
): NetworkResult<Unit> = fromWebSocketRequest(
signalWebSocket = signalWebSocket,
request = request,
clazz = Unit::class
)
/**
* A convenience method to convert a websocket request into a network result with simple conversion of the response body to the desired class.
* Common exceptions will be caught and translated to errors.
* Common HTTP errors will be translated to [StatusCodeError]s.
*/
@JvmStatic
fun <T : Any> fromWebSocketRequest(
signalWebSocket: SignalWebSocket,
request: WebSocketRequestMessage,
clazz: KClass<T>
clazz: KClass<T>,
timeout: Duration = WebSocketConnection.DEFAULT_SEND_TIMEOUT
): NetworkResult<T> {
return fromWebSocketRequest(
signalWebSocket = signalWebSocket,
request = request,
timeout = timeout,
webSocketResponseConverter = DefaultWebSocketConverter(clazz)
)
}
/**
* A convenience method to convert a websocket request into a network result with the ability to fully customize the conversion of the response.
* Common HTTP errors will be translated to [StatusCodeError]s.
*/
@JvmStatic
fun <T : Any> fromWebSocketRequest(
signalWebSocket: SignalWebSocket,
request: WebSocketRequestMessage,
timeout: Duration = WebSocketConnection.DEFAULT_SEND_TIMEOUT,
webSocketResponseConverter: WebSocketResponseConverter<T>
): NetworkResult<T> = try {
val result: Result<T> = signalWebSocket.request(request)
.map { response: WebsocketResponse -> Result.success(JsonUtil.fromJson(response.body, clazz.java)) }
.onErrorReturn { Result.failure<T>(it) }
.blockingGet()
Success(result.getOrThrow())
val result: Result<NetworkResult<T>> = signalWebSocket.request(request, timeout)
.map { response: WebsocketResponse -> Result.success(webSocketResponseConverter.convert(response)) }
.onErrorReturn { Result.failure(it) }
.safeBlockingGet()
result.getOrThrow()
} catch (e: NonSuccessfulResponseCodeException) {
StatusCodeError(e)
} catch (e: IOException) {
NetworkError(e)
} catch (e: TimeoutException) {
NetworkError(PushNetworkException(e))
} catch (e: Throwable) {
ApplicationError(e)
}
/**
* A convenience method to convert a websocket request into a network result with the ability to convert the response to your target class.
* Common exceptions will be caught and translated to errors.
*/
@JvmStatic
fun <T : Any> fromWebSocketRequest(
signalWebSocket: SignalWebSocket,
request: WebSocketRequestMessage,
webSocketResponseConverter: WebSocketResponseConverter<T>
): NetworkResult<T> = try {
val result = signalWebSocket.request(request)
.map { response: WebsocketResponse -> webSocketResponseConverter.convert(response) }
.blockingGet()
Success(result)
} catch (e: NonSuccessfulResponseCodeException) {
StatusCodeError(e)
} catch (e: IOException) {
NetworkError(e)
} catch (e: InterruptedException) {
NetworkError(PushNetworkException(e))
} catch (e: Throwable) {
ApplicationError(e)
}
@@ -308,6 +329,37 @@ sealed class NetworkResult<T>(
fun interface WebSocketResponseConverter<T> {
@Throws(Exception::class)
fun convert(response: WebsocketResponse): T
fun convert(response: WebsocketResponse): NetworkResult<T>
}
class DefaultWebSocketConverter<T : Any>(private val responseJsonClass: KClass<T>) : WebSocketResponseConverter<T> {
override fun convert(response: WebsocketResponse): NetworkResult<T> {
return if (response.status < 200 || response.status > 299) {
response.toStatusCodeError()
} else {
response.toSuccess(responseJsonClass)
}
}
}
class LongPollingWebSocketConverter<T : Any>(private val responseJsonClass: KClass<T>) : WebSocketResponseConverter<T> {
override fun convert(response: WebsocketResponse): NetworkResult<T> {
return if (response.status == 204 || response.status < 200 || response.status > 299) {
response.toStatusCodeError()
} else {
response.toSuccess(responseJsonClass)
}
}
}
}
private fun <T : Any> WebsocketResponse.toStatusCodeError(): NetworkResult<T> {
return StatusCodeError(NonSuccessfulResponseCodeException(this.status, "", this.body))
}
private fun <T : Any> WebsocketResponse.toSuccess(responseJsonClass: KClass<T>): NetworkResult<T> {
if (responseJsonClass == Unit::class) {
return Success(responseJsonClass.cast(Unit))
}
return Success(JsonUtil.fromJson(this.body, responseJsonClass.java))
}

View File

@@ -297,15 +297,6 @@ public class SignalServiceAccountManager {
return pushServiceSocket.getAccountDataReport();
}
public List<DeviceInfo> getDevices() throws IOException {
return this.pushServiceSocket.getDevices();
}
public void removeDevice(int deviceId) throws IOException {
this.pushServiceSocket.removeDevice(deviceId);
}
public List<TurnServerInfo> getTurnServerInfo() throws IOException {
List<TurnServerInfo> relays = this.pushServiceSocket.getCallingRelays().getRelays();
return relays != null ? relays : Collections.emptyList();

View File

@@ -6,6 +6,8 @@
package org.whispersystems.signalservice.api.link
import okio.ByteString.Companion.toByteString
import org.signal.core.util.Base64.encodeWithPadding
import org.signal.core.util.urlEncode
import org.signal.libsignal.protocol.IdentityKeyPair
import org.signal.libsignal.protocol.ecc.ECPublicKey
import org.signal.libsignal.zkgroup.profiles.ProfileKey
@@ -13,18 +15,54 @@ import org.whispersystems.signalservice.api.NetworkResult
import org.whispersystems.signalservice.api.backup.MediaRootBackupKey
import org.whispersystems.signalservice.api.backup.MessageBackupKey
import org.whispersystems.signalservice.api.kbs.MasterKey
import org.whispersystems.signalservice.api.messages.multidevice.DeviceInfo
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import org.whispersystems.signalservice.api.websocket.SignalWebSocket
import org.whispersystems.signalservice.internal.crypto.PrimaryProvisioningCipher
import org.whispersystems.signalservice.internal.delete
import org.whispersystems.signalservice.internal.get
import org.whispersystems.signalservice.internal.push.DeviceInfoList
import org.whispersystems.signalservice.internal.push.ProvisionMessage
import org.whispersystems.signalservice.internal.push.ProvisioningMessage
import org.whispersystems.signalservice.internal.push.ProvisioningVersion
import org.whispersystems.signalservice.internal.push.PushServiceSocket
import kotlin.math.min
import org.whispersystems.signalservice.internal.put
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
/**
* Class to interact with device-linking endpoints.
*/
class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) {
class LinkDeviceApi(
private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket
) {
/**
* Fetches a list of linked devices.
*
* GET /v1/devices
*
* - 200: Success
*/
fun getDevices(): NetworkResult<List<DeviceInfo>> {
val request = WebSocketRequestMessage.get("/v1/devices")
return NetworkResult
.fromWebSocketRequest(authWebSocket, request, DeviceInfoList::class)
.map { it.getDevices() }
}
/**
* Remove and unlink a linked device.
*
* DELETE /v1/devices/{id}
*
* - 200: Success
*/
fun removeDevice(deviceId: Int): NetworkResult<Unit> {
val request = WebSocketRequestMessage.delete("/v1/devices/$deviceId")
return NetworkResult
.fromWebSocketRequest(authWebSocket, request)
}
/**
* Fetches a new verification code that lets you link a new device.
@@ -36,15 +74,15 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) {
* - 429: Rate-limited.
*/
fun getDeviceVerificationCode(): NetworkResult<LinkedDeviceVerificationCodeResponse> {
return NetworkResult.fromFetch {
pushServiceSocket.getLinkedDeviceVerificationCode()
}
val request = WebSocketRequestMessage.get("/v1/devices/provisioning/code")
return NetworkResult
.fromWebSocketRequest(authWebSocket, request, LinkedDeviceVerificationCodeResponse::class)
}
/**
* Links a new device to the account.
*
* PUT /v1/devices/link
* PUT /v1/provisioning/[deviceIdentifier]
*
* - 200: Success.
* - 403: Account not found or incorrect verification code.
@@ -67,45 +105,50 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) {
code: String,
ephemeralMessageBackupKey: MessageBackupKey?
): NetworkResult<Unit> {
return NetworkResult.fromFetch {
val cipher = PrimaryProvisioningCipher(deviceKey)
val message = ProvisionMessage(
aciIdentityKeyPublic = aciIdentityKeyPair.publicKey.serialize().toByteString(),
aciIdentityKeyPrivate = aciIdentityKeyPair.privateKey.serialize().toByteString(),
pniIdentityKeyPublic = pniIdentityKeyPair.publicKey.serialize().toByteString(),
pniIdentityKeyPrivate = pniIdentityKeyPair.privateKey.serialize().toByteString(),
aci = aci.toString(),
pni = pni.toStringWithoutPrefix(),
number = e164,
profileKey = profileKey.serialize().toByteString(),
provisioningCode = code,
provisioningVersion = ProvisioningVersion.CURRENT.value,
masterKey = masterKey.serialize().toByteString(),
mediaRootBackupKey = mediaRootBackupKey.value.toByteString(),
ephemeralBackupKey = ephemeralMessageBackupKey?.value?.toByteString()
)
val ciphertext = cipher.encrypt(message)
val cipher = PrimaryProvisioningCipher(deviceKey)
val message = ProvisionMessage(
aciIdentityKeyPublic = aciIdentityKeyPair.publicKey.serialize().toByteString(),
aciIdentityKeyPrivate = aciIdentityKeyPair.privateKey.serialize().toByteString(),
pniIdentityKeyPublic = pniIdentityKeyPair.publicKey.serialize().toByteString(),
pniIdentityKeyPrivate = pniIdentityKeyPair.privateKey.serialize().toByteString(),
aci = aci.toString(),
pni = pni.toStringWithoutPrefix(),
number = e164,
profileKey = profileKey.serialize().toByteString(),
provisioningCode = code,
provisioningVersion = ProvisioningVersion.CURRENT.value,
masterKey = masterKey.serialize().toByteString(),
mediaRootBackupKey = mediaRootBackupKey.value.toByteString(),
ephemeralBackupKey = ephemeralMessageBackupKey?.value?.toByteString()
)
val ciphertext: ByteArray = cipher.encrypt(message)
val body = ProvisioningMessage(encodeWithPadding(ciphertext))
pushServiceSocket.sendProvisioningMessage(deviceIdentifier, ciphertext)
}
val request = WebSocketRequestMessage.put("/v1/provisioning/${deviceIdentifier.urlEncode()}", body)
return NetworkResult.fromWebSocketRequest(authWebSocket, request)
}
/**
* A "long-polling" endpoint that will return once the device has successfully been linked.
*
* @param timeoutSeconds The max amount of time to wait. Capped at 30 seconds.
* @param timeout The max amount of time to wait. Capped at 30 seconds.
*
* GET /v1/devices/wait_for_linked_device/{token}
* GET /v1/devices/wait_for_linked_device/[token]?timeout=[timeout]
*
* - 200: Success, a new device was linked associated with the provided token.
* - 204: No device was linked before the max waiting time elapsed.
* - 400: Invalid token/timeout.
* - 429: Rate-limited.
*/
fun waitForLinkedDevice(token: String, timeoutSeconds: Int = 30): NetworkResult<WaitForLinkedDeviceResponse> {
return NetworkResult.fromFetch {
pushServiceSocket.waitForLinkedDevice(token, min(timeoutSeconds, 30))
}
fun waitForLinkedDevice(token: String, timeout: Duration = 30.seconds): NetworkResult<WaitForLinkedDeviceResponse> {
val request = WebSocketRequestMessage.get("/v1/devices/wait_for_linked_device/${token.urlEncode()}?timeout=${timeout.inWholeSeconds}")
return NetworkResult
.fromWebSocketRequest(
signalWebSocket = authWebSocket,
request = request,
timeout = timeout,
webSocketResponseConverter = NetworkResult.LongPollingWebSocketConverter(WaitForLinkedDeviceResponse::class)
)
}
/**
@@ -118,18 +161,16 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) {
* - 429: Rate-limited.
*/
fun setTransferArchive(destinationDeviceId: Int, destinationDeviceCreated: Long, cdn: Int, cdnKey: String): NetworkResult<Unit> {
return NetworkResult.fromFetch {
pushServiceSocket.setLinkedDeviceTransferArchive(
SetLinkedDeviceTransferArchiveRequest(
destinationDeviceId = destinationDeviceId,
destinationDeviceCreated = destinationDeviceCreated,
transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.CdnInfo(
cdn = cdn,
key = cdnKey
)
)
val body = SetLinkedDeviceTransferArchiveRequest(
destinationDeviceId = destinationDeviceId,
destinationDeviceCreated = destinationDeviceCreated,
transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.CdnInfo(
cdn = cdn,
key = cdnKey
)
}
)
val request = WebSocketRequestMessage.put("/v1/devices/transfer_archive", body)
return NetworkResult.fromWebSocketRequest(authWebSocket, request)
}
/**
@@ -143,31 +184,26 @@ class LinkDeviceApi(private val pushServiceSocket: PushServiceSocket) {
* - 429: Rate-limited.
*/
fun setTransferArchiveError(destinationDeviceId: Int, destinationDeviceCreated: Long, error: TransferArchiveError): NetworkResult<Unit> {
return NetworkResult.fromFetch {
pushServiceSocket.setLinkedDeviceTransferArchive(
SetLinkedDeviceTransferArchiveRequest(
destinationDeviceId = destinationDeviceId,
destinationDeviceCreated = destinationDeviceCreated,
transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.Error(
error
)
)
)
}
val body = SetLinkedDeviceTransferArchiveRequest(
destinationDeviceId = destinationDeviceId,
destinationDeviceCreated = destinationDeviceCreated,
transferArchive = SetLinkedDeviceTransferArchiveRequest.TransferArchive.Error(error)
)
val request = WebSocketRequestMessage.put("/v1/devices/transfer_archive", body)
return NetworkResult.fromWebSocketRequest(authWebSocket, request)
}
/**
* Sets the name for a linked device
*
* PUT /v1/accounts/name
* PUT /v1/accounts/name?deviceId=[deviceId]
*
* - 204: Success.
* - 403: Not authorized to change the name of the device with the given ID
* - 404: No device found with the given ID
*/
fun setDeviceName(encryptedDeviceName: String, deviceId: Int): NetworkResult<Unit> {
return NetworkResult.fromFetch {
pushServiceSocket.setDeviceName(deviceId, SetDeviceNameRequest(encryptedDeviceName))
}
val request = WebSocketRequestMessage.put("/v1/accounts/name?deviceId=$deviceId", SetDeviceNameRequest(encryptedDeviceName))
return NetworkResult.fromWebSocketRequest(authWebSocket, request)
}
}

View File

@@ -23,6 +23,7 @@ import org.whispersystems.signalservice.internal.websocket.WebSocketResponseMess
import org.whispersystems.signalservice.internal.websocket.WebsocketResponse
import java.io.IOException
import java.util.concurrent.TimeoutException
import kotlin.time.Duration
/**
* Base wrapper around a [WebSocketConnection] to provide a more developer friend interface to websocket
@@ -100,6 +101,14 @@ sealed class SignalWebSocket(
}
}
fun request(request: WebSocketRequestMessage, timeout: Duration): Single<WebsocketResponse> {
return try {
getWebSocket().sendRequest(request, timeout.inWholeSeconds)
} catch (e: IOException) {
Single.error(e)
}
}
@Throws(IOException::class)
fun sendAck(response: EnvelopeResponse) {
getWebSocket().sendResponse(response.websocketRequest.getWebSocketResponse())

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.signalservice.internal
import org.whispersystems.signalservice.internal.util.JsonUtil
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
import java.security.SecureRandom
/**
* Create a basic GET web socket request
*/
fun WebSocketRequestMessage.Companion.get(path: String): WebSocketRequestMessage {
return WebSocketRequestMessage(
verb = "GET",
path = path,
id = SecureRandom().nextLong()
)
}
/**
* Create a basic DELETE web socket request
*/
fun WebSocketRequestMessage.Companion.delete(path: String): WebSocketRequestMessage {
return WebSocketRequestMessage(
verb = "DELETE",
path = path,
id = SecureRandom().nextLong()
)
}
/**
* Create a basic PUT web socket request, where body is JSON-ified.
*/
fun WebSocketRequestMessage.Companion.put(path: String, body: Any): WebSocketRequestMessage {
return WebSocketRequestMessage(
verb = "PUT",
path = path,
headers = listOf("content-type:application/json"),
body = JsonUtil.toJsonByteString(body),
id = SecureRandom().nextLong()
)
}

View File

@@ -253,11 +253,7 @@ public class PushServiceSocket {
private static final String CALLING_RELAYS = "/v2/calling/relays";
private static final String PROVISIONING_CODE_PATH = "/v1/devices/provisioning/code";
private static final String PROVISIONING_MESSAGE_PATH = "/v1/provisioning/%s";
private static final String DEVICE_PATH = "/v1/devices/%s";
private static final String WAIT_FOR_DEVICES_PATH = "/v1/devices/wait_for_linked_device/%s?timeout=%s";
private static final String TRANSFER_ARCHIVE_PATH = "/v1/devices/transfer_archive";
private static final String SET_RESTORE_METHOD_PATH = "/v1/devices/restore_account/%s";
private static final String WAIT_RESTORE_METHOD_PATH = "/v1/devices/restore_account/%s?timeout=%s";
@@ -698,29 +694,6 @@ public class PushServiceSocket {
makeServiceRequest(SET_ACCOUNT_ATTRIBUTES, "PUT", JsonUtil.toJson(accountAttributes));
}
public LinkedDeviceVerificationCodeResponse getLinkedDeviceVerificationCode() throws IOException {
String responseText = makeServiceRequest(PROVISIONING_CODE_PATH, "GET", null, NO_HEADERS, UNOPINIONATED_HANDLER, SealedSenderAccess.NONE);
return JsonUtil.fromJson(responseText, LinkedDeviceVerificationCodeResponse.class);
}
public List<DeviceInfo> getDevices() throws IOException {
String responseText = makeServiceRequest(String.format(DEVICE_PATH, ""), "GET", null);
return JsonUtil.fromJson(responseText, DeviceInfoList.class).getDevices();
}
/**
* This is a long-polling endpoint that relies on the fact that our normal connection timeout is already 30s.
*/
public WaitForLinkedDeviceResponse waitForLinkedDevice(String token, int timeoutSeconds) throws IOException {
String response = makeServiceRequest(String.format(Locale.US, WAIT_FOR_DEVICES_PATH, token, timeoutSeconds), "GET", null, NO_HEADERS, LONG_POLL_HANDLER, SealedSenderAccess.NONE);
return JsonUtil.fromJsonResponse(response, WaitForLinkedDeviceResponse.class);
}
public void setLinkedDeviceTransferArchive(SetLinkedDeviceTransferArchiveRequest request) throws IOException {
String body = JsonUtil.toJson(request);
makeServiceRequest(String.format(Locale.US, TRANSFER_ARCHIVE_PATH), "PUT", body, NO_HEADERS, UNOPINIONATED_HANDLER, SealedSenderAccess.NONE);
}
public void setRestoreMethodChosen(@Nonnull String token, @Nonnull RestoreMethodBody request) throws IOException {
String body = JsonUtil.toJson(request);
makeServiceRequest(String.format(Locale.US, SET_RESTORE_METHOD_PATH, urlEncode(token)), "PUT", body, NO_HEADERS, UNOPINIONATED_HANDLER, SealedSenderAccess.NONE);
@@ -734,10 +707,6 @@ public class PushServiceSocket {
return JsonUtil.fromJsonResponse(response, RestoreMethodBody.class);
}
public void removeDevice(long deviceId) throws IOException {
makeServiceRequest(String.format(DEVICE_PATH, String.valueOf(deviceId)), "DELETE", null);
}
public void sendProvisioningMessage(String destination, byte[] body) throws IOException {
makeServiceRequest(String.format(PROVISIONING_MESSAGE_PATH, urlEncode(destination)), "PUT",
JsonUtil.toJson(new ProvisioningMessage(Base64.encodeWithPadding(body))));

View File

@@ -37,6 +37,7 @@ import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
import org.signal.libsignal.net.ChatConnection.Request as LibSignalRequest
@@ -95,17 +96,16 @@ class LibSignalChatConnection(
const val SIGNAL_SERVICE_ENVELOPE_TIMESTAMP_HEADER_KEY = "X-Signal-Timestamp"
private val TAG = Log.tag(LibSignalChatConnection::class.java)
private val SEND_TIMEOUT: Long = 10.seconds.inWholeMilliseconds
private val KEEP_ALIVE_REQUEST = LibSignalRequest(
"GET",
"/v1/keepalive",
emptyMap(),
ByteArray(0),
SEND_TIMEOUT.toInt()
WebSocketConnection.DEFAULT_SEND_TIMEOUT.inWholeMilliseconds.toInt()
)
private fun WebSocketRequestMessage.toLibSignalRequest(timeout: Long = SEND_TIMEOUT): LibSignalRequest {
private fun WebSocketRequestMessage.toLibSignalRequest(timeout: Duration = WebSocketConnection.DEFAULT_SEND_TIMEOUT): LibSignalRequest {
return LibSignalRequest(
this.verb?.uppercase() ?: "GET",
this.path ?: "",
@@ -117,7 +117,7 @@ class LibSignalChatConnection(
parts[0] to parts[1]
},
this.body?.toByteArray() ?: byteArrayOf(),
timeout.toInt()
timeout.inWholeMilliseconds.toInt()
)
}
}
@@ -254,7 +254,7 @@ class LibSignalChatConnection(
}
}
override fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
override fun sendRequest(request: WebSocketRequestMessage, timeoutSeconds: Long): Single<WebsocketResponse> {
CHAT_SERVICE_LOCK.withLock {
if (isDead()) {
return Single.error(IOException("$name is closed!"))
@@ -294,7 +294,7 @@ class LibSignalChatConnection(
return single.subscribeOn(Schedulers.io()).observeOn(Schedulers.io())
}
val internalRequest = request.toLibSignalRequest()
val internalRequest = request.toLibSignalRequest(timeout = timeoutSeconds.seconds)
chatConnection!!.send(internalRequest)
.whenComplete(
onSuccess = { response ->

View File

@@ -1,5 +1,6 @@
package org.whispersystems.signalservice.internal.websocket;
import org.jetbrains.annotations.NotNull;
import org.signal.libsignal.protocol.logging.Log;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.signalservice.api.push.TrustStore;
@@ -227,7 +228,7 @@ public class OkHttpWebSocketConnection extends WebSocketListener implements WebS
}
@Override
public synchronized Single<WebsocketResponse> sendRequest(WebSocketRequestMessage request) throws IOException {
public synchronized Single<WebsocketResponse> sendRequest(@NotNull WebSocketRequestMessage request, long timeoutSeconds) throws IOException {
if (client == null) {
throw new IOException("No connection!");
}
@@ -247,7 +248,7 @@ public class OkHttpWebSocketConnection extends WebSocketListener implements WebS
return single.subscribeOn(Schedulers.io())
.observeOn(Schedulers.io())
.timeout(10, TimeUnit.SECONDS, Schedulers.io());
.timeout(timeoutSeconds, TimeUnit.SECONDS, Schedulers.io());
}
@Override

View File

@@ -6,6 +6,7 @@ import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import java.io.IOException
import java.util.Optional
import java.util.concurrent.TimeoutException
import kotlin.time.Duration.Companion.seconds
/**
* Common interface for the web socket connection API
@@ -15,6 +16,10 @@ import java.util.concurrent.TimeoutException
* - LibSignalChatConnection - the wrapper around libsignal's [org.signal.libsignal.net.ChatService]
*/
interface WebSocketConnection {
companion object {
val DEFAULT_SEND_TIMEOUT = 10.seconds
}
val name: String
fun connect(): Observable<WebSocketConnectionState>
@@ -24,7 +29,12 @@ interface WebSocketConnection {
fun disconnect()
@Throws(IOException::class)
fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse>
fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
return sendRequest(request, DEFAULT_SEND_TIMEOUT.inWholeSeconds)
}
@Throws(IOException::class)
fun sendRequest(request: WebSocketRequestMessage, timeoutSeconds: Long): Single<WebsocketResponse>
@Throws(IOException::class)
fun sendKeepAlive()