diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/username/UsernameApi.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/username/UsernameApi.kt index 070a20ad14..942b3bd281 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/username/UsernameApi.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/username/UsernameApi.kt @@ -31,7 +31,7 @@ class UsernameApi(private val unauthWebSocket: SignalWebSocket.UnauthenticatedWe */ fun getAciByUsername(username: Username): RequestResult { return runBlocking { - unauthWebSocket.runWithUnauthChatConnection { chatConnection -> + unauthWebSocket.runCatchingWithUnauthChatConnection { chatConnection -> UnauthUsernamesService(chatConnection).lookUpUsernameHash(username.hash) }.getOrError().map { it?.let { ServiceId.ACI.fromLibSignal(it) } } } diff --git a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt index 55044aa280..d88b9ce2dd 100644 --- a/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt +++ b/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt @@ -14,7 +14,10 @@ import io.reactivex.rxjava3.schedulers.Schedulers import io.reactivex.rxjava3.subjects.BehaviorSubject import org.signal.core.util.logging.Log import org.signal.core.util.orNull -import org.signal.libsignal.net.ChatConnection +import org.signal.libsignal.internal.CompletableFuture +import org.signal.libsignal.net.BadRequestError +import org.signal.libsignal.net.RequestResult +import org.signal.libsignal.net.UnauthenticatedChatConnection import org.whispersystems.signalservice.api.crypto.SealedSenderAccess import org.whispersystems.signalservice.api.messages.EnvelopeResponse import org.whispersystems.signalservice.api.util.SleepTimer @@ -24,6 +27,8 @@ import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessa import org.whispersystems.signalservice.internal.websocket.WebSocketResponseMessage import org.whispersystems.signalservice.internal.websocket.WebsocketResponse import java.io.IOException +import java.util.concurrent.CancellationException +import java.util.concurrent.CompletionException import java.util.concurrent.CopyOnWriteArraySet import java.util.concurrent.TimeoutException import kotlin.time.Duration @@ -325,8 +330,26 @@ sealed class SignalWebSocket( } } - suspend fun runWithUnauthChatConnection(callback: (org.signal.libsignal.net.UnauthenticatedChatConnection) -> T): T { - return getWebSocket().runWithChatConnection(callback as (ChatConnection) -> T) + suspend fun runCatchingWithUnauthChatConnection( + callback: (UnauthenticatedChatConnection) -> CompletableFuture> + ): CompletableFuture> { + val requestFuture = try { + getWebSocket().runWithChatConnection { chatConnection -> + val unauthenticatedConnection = chatConnection as? UnauthenticatedChatConnection + ?: throw IllegalStateException("Expected unauthenticated chat connection but got ${chatConnection::class.java.simpleName}") + callback(unauthenticatedConnection) + } + } catch (throwable: Throwable) { + return CompletableFuture.completedFuture(throwable.toNetworkRequestResult()) + } + + return requestFuture.handle { result, throwable -> + when { + throwable != null -> throwable.toNetworkRequestResult() + result != null -> result + else -> RequestResult.ApplicationError(IllegalStateException("RequestResult was null")) + } + } } } @@ -447,3 +470,16 @@ sealed class SignalWebSocket( fun canConnect(): Boolean } } + +private fun Throwable.toNetworkRequestResult(): RequestResult { + val cause = if (this is CompletionException && this.cause != null) { + this.cause!! + } else { + this + } + return when (cause) { + is IOException -> RequestResult.RetryableNetworkError(cause) + is CancellationException -> RequestResult.RetryableNetworkError(IOException("Request cancelled", cause)) + else -> RequestResult.ApplicationError(cause) + } +}