Improve auth WebSocket lifecycle.

This commit is contained in:
Cody Henthorne
2025-03-18 13:38:21 -04:00
committed by Alex Hart
parent 6bbd899507
commit 323697dfc9
16 changed files with 300 additions and 205 deletions

View File

@@ -78,7 +78,7 @@ class AccountApi(private val authWebSocket: SignalWebSocket.AuthenticatedWebSock
/**
* PUT /v1/accounts/registration_lock
* - 200: Success
* - 204: Success
*/
fun enableRegistrationLock(registrationLock: String): NetworkResult<Unit> {
val request = WebSocketRequestMessage.put("/v1/accounts/registration_lock", PushServiceSocket.RegistrationLockV2(registrationLock))

View File

@@ -17,6 +17,10 @@ public interface CredentialsProvider {
int getDeviceId();
String getPassword();
default boolean isInvalid() {
return (getAci() == null && getE164() == null) || getPassword() == null;
}
default String getUsername() {
StringBuilder sb = new StringBuilder();
sb.append(getAci().toString());

View File

@@ -16,6 +16,7 @@ import org.signal.core.util.logging.Log
import org.signal.core.util.orNull
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess
import org.whispersystems.signalservice.api.messages.EnvelopeResponse
import org.whispersystems.signalservice.api.util.SleepTimer
import org.whispersystems.signalservice.internal.push.Envelope
import org.whispersystems.signalservice.internal.websocket.WebSocketConnection
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
@@ -24,42 +25,56 @@ import org.whispersystems.signalservice.internal.websocket.WebsocketResponse
import java.io.IOException
import java.util.concurrent.TimeoutException
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
/**
* Base wrapper around a [WebSocketConnection] to provide a more developer friend interface to websocket
* interactions.
*/
sealed class SignalWebSocket(
private val createConnection: () -> WebSocketConnection
private val connectionFactory: WebSocketFactory,
val sleepTimer: SleepTimer,
private val disconnectTimeout: Duration
) {
companion object {
private val TAG = Log.tag(SignalWebSocket::class)
const val SERVER_DELIVERED_TIMESTAMP_HEADER = "X-Signal-Timestamp"
const val FOREGROUND_KEEPALIVE = "Foregrounded"
/**
* Set to false to prevent web sockets from connecting. After setting back to true the caller
* must manually start the sockets again by calling [connect].
*/
@Volatile
@JvmStatic
var canConnect: Boolean = true
}
private var connection: WebSocketConnection? = null
private val connectionName
get() = connection?.name ?: "[null]"
private val _state: BehaviorSubject<WebSocketConnectionState> = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
protected var disposable: CompositeDisposable = CompositeDisposable()
private var canConnect = false
var shouldSendKeepAlives: Boolean = true
set(value) {
field = value
keepAliveChangedListener?.invoke()
}
private val keepAliveTokens: MutableSet<String> = mutableSetOf()
var keepAliveChangedListener: (() -> Unit)? = null
private var delayedDisconnectThread: DelayedDisconnectThread? = null
val state: Observable<WebSocketConnectionState> = _state
val stateSnapshot: WebSocketConnectionState
get() = _state.value!!
/**
* Indicate that WebSocketConnection can now be made and attempt to connect.
*/
@Synchronized
@Throws(WebSocketUnavailableException::class)
fun connect() {
canConnect = true
getWebSocket()
}
@@ -68,11 +83,6 @@ sealed class SignalWebSocket(
*/
@Synchronized
fun disconnect() {
canConnect = false
disconnectInternal()
}
private fun disconnectInternal() {
if (connection != null) {
disposable.dispose()
@@ -89,12 +99,53 @@ sealed class SignalWebSocket(
@Throws(IOException::class)
fun sendKeepAlive() {
if (canConnect) {
Log.v(TAG, "$connectionName keepAliveTokens: $keepAliveTokens")
getWebSocket().sendKeepAlive()
}
}
@Synchronized
fun shouldSendKeepAlives(): Boolean {
return keepAliveTokens.isNotEmpty()
}
@Synchronized
fun registerKeepAliveToken(token: String) {
delayedDisconnectThread?.abort()
delayedDisconnectThread = null
val changed = keepAliveTokens.add(token)
if (changed) {
Log.v(TAG, "$connectionName Adding keepAliveToken: $token, current: $keepAliveTokens")
}
if (canConnect) {
try {
connect()
} catch (e: WebSocketUnavailableException) {
Log.w(TAG, "$connectionName Keep alive requested, but connection not available", e)
}
} else {
Log.w(TAG, "$connectionName Keep alive requested, but connection not available")
}
if (changed) {
keepAliveChangedListener?.invoke()
}
}
@Synchronized
fun removeKeepAliveToken(token: String) {
if (keepAliveTokens.remove(token)) {
Log.v(TAG, "$connectionName Removing keepAliveToken: $token, remaining: $keepAliveTokens")
startDelayedDisconnectIfNecessary()
keepAliveChangedListener?.invoke()
}
}
fun request(request: WebSocketRequestMessage): Single<WebsocketResponse> {
return try {
delayedDisconnectThread?.resetLastInteractionTime()
getWebSocket().sendRequest(request)
} catch (e: IOException) {
Single.error(e)
@@ -103,6 +154,7 @@ sealed class SignalWebSocket(
fun request(request: WebSocketRequestMessage, timeout: Duration): Single<WebsocketResponse> {
return try {
delayedDisconnectThread?.resetLastInteractionTime()
getWebSocket().sendRequest(request, timeout.inWholeSeconds)
} catch (e: IOException) {
Single.error(e)
@@ -125,7 +177,7 @@ sealed class SignalWebSocket(
disposable.dispose()
disposable = CompositeDisposable()
val newConnection = createConnection()
val newConnection = connectionFactory.createConnection()
newConnection
.connect()
@@ -135,15 +187,70 @@ sealed class SignalWebSocket(
.addTo(disposable)
this.connection = newConnection
startDelayedDisconnectIfNecessary()
}
return connection!!
}
private fun startDelayedDisconnectIfNecessary() {
if (connection.isAlive() && keepAliveTokens.isEmpty()) {
delayedDisconnectThread?.abort()
delayedDisconnectThread = DelayedDisconnectThread().also { it.start() }
}
}
@Synchronized
fun forceNewWebSocket() {
Log.i(TAG, "Forcing new WebSockets connection: ${connection?.name ?: "[null]"} canConnect: $canConnect")
disconnectInternal()
Log.i(TAG, "$connectionName Forcing new WebSocket, canConnect: $canConnect")
disconnect()
}
/**
* Allow the WebSocket to self destruct if there are no keep alive tokens and it's been longer
* than [disconnectTimeout] since the last request was made.
*/
private inner class DelayedDisconnectThread : Thread() {
private var abort = false
@Volatile
private var lastInteractionTime = Duration.ZERO
fun abort() {
if (!abort && isAlive) {
Log.v(TAG, "$connectionName Scheduled disconnect aborted.")
abort = true
interrupt()
}
}
fun resetLastInteractionTime() {
lastInteractionTime = System.currentTimeMillis().milliseconds
}
override fun run() {
lastInteractionTime = System.currentTimeMillis().milliseconds
try {
while (!abort && (lastInteractionTime + disconnectTimeout) > System.currentTimeMillis().milliseconds) {
val now = System.currentTimeMillis().milliseconds
if (lastInteractionTime > now) {
lastInteractionTime = now
}
val sleepDuration = (lastInteractionTime + disconnectTimeout) - now
Log.v(TAG, "$connectionName Disconnect scheduled in $sleepDuration")
sleepTimer.sleep(sleepDuration.inWholeMilliseconds)
}
} catch (_: InterruptedException) { }
if (!abort && !shouldSendKeepAlives()) {
disconnect()
}
}
}
private fun WebSocketConnection?.isAlive(): Boolean {
return this?.isDead() == false
}
protected fun WebSocketRequestMessage.isSignalServiceEnvelope(): Boolean {
@@ -173,7 +280,7 @@ sealed class SignalWebSocket(
/**
* WebSocket type for communicating with the server without authenticating. Also known as "unidentified".
*/
class UnauthenticatedWebSocket(createConnection: () -> WebSocketConnection) : SignalWebSocket(createConnection) {
class UnauthenticatedWebSocket(connectionFactory: WebSocketFactory, sleepTimer: SleepTimer, disconnectTimeoutMs: Long) : SignalWebSocket(connectionFactory, sleepTimer, disconnectTimeoutMs.milliseconds) {
fun request(requestMessage: WebSocketRequestMessage, sealedSenderAccess: SealedSenderAccess): Single<WebsocketResponse> {
val headers: MutableList<String> = requestMessage.headers.toMutableList()
headers.add(sealedSenderAccess.header)
@@ -184,8 +291,7 @@ sealed class SignalWebSocket(
.build()
try {
return getWebSocket()
.sendRequest(message)
return request(message)
.flatMap<WebsocketResponse> { response ->
if (response.status == 401) {
val fallback = sealedSenderAccess.switchToFallback()
@@ -204,7 +310,7 @@ sealed class SignalWebSocket(
/**
* WebSocket type for communicating with the server with authentication. Also known as "identified".
*/
class AuthenticatedWebSocket(createConnection: () -> WebSocketConnection) : SignalWebSocket(createConnection) {
class AuthenticatedWebSocket(connectionFactory: WebSocketFactory, sleepTimer: SleepTimer, disconnectTimeoutMs: Long) : SignalWebSocket(connectionFactory, sleepTimer, disconnectTimeoutMs.milliseconds) {
/**
* The reads a batch of messages off of the websocket.

View File

@@ -3,6 +3,5 @@ package org.whispersystems.signalservice.api.websocket;
import org.whispersystems.signalservice.internal.websocket.WebSocketConnection;
public interface WebSocketFactory {
WebSocketConnection createWebSocket();
WebSocketConnection createUnidentifiedWebSocket();
WebSocketConnection createConnection() throws WebSocketUnavailableException;
}

View File

@@ -4,11 +4,15 @@ import java.io.IOException;
/**
* Thrown when the WebSocket is not available for use by runtime policy. Currently, the
* WebSocket is only available when the app is in the foreground and requested via IncomingMessageObserver.
* Or, when using WebSocket Strategy.
* WebSocket is only unavailable when networking is blocked by a device transfer or if
* requesting to connect via auth but provide no auth credentials.
*/
public final class WebSocketUnavailableException extends IOException {
public WebSocketUnavailableException() {
super("WebSocket not currently available.");
}
public WebSocketUnavailableException(String reason) {
super("WebSocket not currently available. Reason: " + reason);
}
}