Improve web socket behaviors around keep alive and shutdown.

This commit is contained in:
Cody Henthorne
2026-04-15 14:31:20 -04:00
committed by jeffrey-signal
parent 3804890265
commit a797bbf850
5 changed files with 66 additions and 73 deletions
@@ -342,7 +342,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
@Override
public @NonNull SignalWebSocket.AuthenticatedWebSocket provideAuthWebSocket(@NonNull Supplier<SignalServiceConfiguration> signalServiceConfigurationSupplier, @NonNull Supplier<Network> libSignalNetworkSupplier) {
SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer();
SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer);
SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer, true);
WebSocketFactory authFactory = () -> {
DynamicCredentialsProvider credentialsProvider = new DynamicCredentialsProvider();
@@ -375,7 +375,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
@Override
public @NonNull SignalWebSocket.UnauthenticatedWebSocket provideUnauthWebSocket(@NonNull Supplier<SignalServiceConfiguration> signalServiceConfigurationSupplier, @NonNull Supplier<Network> libSignalNetworkSupplier) {
SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer();
SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer);
SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer, false);
WebSocketFactory unauthFactory = () -> {
Network network = libSignalNetworkSupplier.get();
@@ -10,13 +10,13 @@ import io.reactivex.rxjava3.schedulers.Schedulers
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.keyvalue.SignalStore
import org.thoughtcrime.securesms.net.SignalWebSocketHealthMonitor.Companion.KEEP_ALIVE_SEND_CADENCE
import org.thoughtcrime.securesms.net.SignalWebSocketHealthMonitor.Companion.KEEP_ALIVE_TIMEOUT
import org.thoughtcrime.securesms.util.AppForegroundObserver
import org.thoughtcrime.securesms.util.TextSecurePreferences
import org.whispersystems.signalservice.api.util.SleepTimer
import org.whispersystems.signalservice.api.websocket.HealthMonitor
@@ -30,22 +30,15 @@ import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
class SignalWebSocketHealthMonitor(
private val sleepTimer: SleepTimer
private val sleepTimer: SleepTimer,
private val sendKeepAlives: Boolean = true
) : HealthMonitor {
// Logging tag and keep-alive timing constants shared by all health monitor instances.
companion object {
private val TAG = Log.tag(SignalWebSocketHealthMonitor::class)
/**
* This is the amount of time in between sent keep alives. Must be greater than [KEEP_ALIVE_TIMEOUT]
*/
private val KEEP_ALIVE_SEND_CADENCE: Duration = OkHttpWebSocketConnection.KEEPALIVE_FREQUENCY_SECONDS.seconds
/**
* This is the amount of time we will wait for a response to the keep alive before we consider the websockets dead.
* It is required that this value be less than [KEEP_ALIVE_SEND_CADENCE]
*/
private val KEEP_ALIVE_TIMEOUT: Duration = 20.seconds
/**
* Amount of time in between sent keep alives when the app is not foregrounded; the keep alive sender
* picks this instead of [KEEP_ALIVE_SEND_CADENCE] in that case.
*/
private val KEEP_ALIVE_SEND_CADENCE_BACKGROUND: Duration = 60.seconds
}
private val executor: Executor = Executors.newSingleThreadExecutor()
@@ -56,7 +49,7 @@ class SignalWebSocketHealthMonitor(
private var needsKeepAlive = false
private var lastKeepAliveReceived: Duration = 0.seconds
private val scope = CoroutineScope(Dispatchers.IO)
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private var connectingTimeoutJob: Job? = null
private var failedInConnecting: Boolean = false
@@ -74,7 +67,9 @@ class SignalWebSocketHealthMonitor(
.distinctUntilChanged()
.subscribeBy { onStateChanged(it) }
webSocket.addKeepAliveChangeListener { executor.execute(this::updateKeepAliveSenderStatus) }
if (sendKeepAlives) {
webSocket.addKeepAliveChangeListener { executor.execute(this::updateKeepAliveSenderStatus) }
}
}
}
@@ -111,7 +106,7 @@ class SignalWebSocketHealthMonitor(
else -> Unit
}
needsKeepAlive = connectionState == WebSocketConnectionState.CONNECTED
needsKeepAlive = connectionState == WebSocketConnectionState.CONNECTED && sendKeepAlives
if (connectionState != WebSocketConnectionState.CONNECTING) {
connectingTimeoutJob?.let {
@@ -163,7 +158,7 @@ class SignalWebSocketHealthMonitor(
/**
* Sends periodic heartbeats/keep-alives over the WebSocket to prevent connection timeouts. If
* the WebSocket fails to get a return heartbeat after [KEEP_ALIVE_TIMEOUT] seconds, it is forced to be recreated.
* the WebSocket fails to get a return heartbeat before the next keep alive is sent, it is forced to be recreated.
*/
private inner class KeepAliveSender : Thread() {
@@ -178,11 +173,12 @@ class SignalWebSocketHealthMonitor(
var hasSentKeepAlive = false
while (shouldKeepRunning && sendKeepAlives()) {
try {
sleepUntil(keepAliveSentTime + KEEP_ALIVE_SEND_CADENCE)
val cadence = if (AppForegroundObserver.isForegrounded()) KEEP_ALIVE_SEND_CADENCE else KEEP_ALIVE_SEND_CADENCE_BACKGROUND
sleepUntil(keepAliveSentTime + cadence)
if (shouldKeepRunning && sendKeepAlives()) {
if (hasSentKeepAlive && lastKeepAliveReceived < keepAliveSentTime) {
Log.w(TAG, "Missed keep alive, last: ${lastKeepAliveReceived.inWholeMilliseconds} needed by: ${(keepAliveSentTime + KEEP_ALIVE_TIMEOUT).inWholeMilliseconds}")
Log.w(TAG, "Missed keep alive, last: ${lastKeepAliveReceived.inWholeMilliseconds} needed by: ${keepAliveSentTime.inWholeMilliseconds}")
webSocket?.forceNewWebSocket()
}
@@ -190,6 +186,8 @@ class SignalWebSocketHealthMonitor(
webSocket?.sendKeepAlive()
hasSentKeepAlive = true
}
} catch (e: InterruptedException) {
// Stopped
} catch (e: Throwable) {
Log.w(TAG, e)
}
@@ -198,7 +196,7 @@ class SignalWebSocketHealthMonitor(
}
fun sleepUntil(time: Duration) {
while (System.currentTimeMillis().milliseconds < time) {
while (shouldKeepRunning && System.currentTimeMillis().milliseconds < time) {
val waitTime = time - System.currentTimeMillis().milliseconds
if (waitTime.isPositive()) {
try {
@@ -212,6 +210,7 @@ class SignalWebSocketHealthMonitor(
/**
* Asks this keep alive sender thread to stop.
*/
fun shutdown() {
// Clear the run flag first so the loops that check it exit once the thread wakes up.
shouldKeepRunning = false
// Interrupt the thread so a sleep that is currently in progress is cut short.
interrupt()
}
}
}
@@ -12,6 +12,12 @@ import io.reactivex.rxjava3.kotlin.addTo
import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.schedulers.Schedulers
import io.reactivex.rxjava3.subjects.BehaviorSubject
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import org.signal.core.util.logging.Log
import org.signal.core.util.orNull
import org.signal.libsignal.internal.CompletableFuture
@@ -65,7 +71,8 @@ sealed class SignalWebSocket(
private val keepAliveTokens: MutableSet<String> = CopyOnWriteArraySet()
private val keepAliveChangeListeners: MutableSet<Listener> = CopyOnWriteArraySet()
private var delayedDisconnectThread: DelayedDisconnectThread? = null
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private var delayedDisconnectJob: Job? = null
val state: Observable<WebSocketConnectionState> = _state
val stateSnapshot: WebSocketConnectionState
@@ -88,7 +95,7 @@ sealed class SignalWebSocket(
if (connection != null) {
disposable.dispose()
connection!!.disconnect()
connection!!.shutdown()
connection = null
if (!_state.value!!.isFailure) {
@@ -117,8 +124,8 @@ sealed class SignalWebSocket(
}
synchronized(this) {
delayedDisconnectThread?.abort()
delayedDisconnectThread = null
delayedDisconnectJob?.cancel()
delayedDisconnectJob = null
if (canConnect.canConnect()) {
try {
@@ -153,7 +160,7 @@ sealed class SignalWebSocket(
fun request(request: WebSocketRequestMessage): Single<WebsocketResponse> {
return try {
delayedDisconnectThread?.resetLastInteractionTime()
restartDelayedDisconnectIfNecessary()
getWebSocket().sendRequest(request)
} catch (e: IOException) {
Single.error(e)
@@ -162,7 +169,7 @@ sealed class SignalWebSocket(
fun request(request: WebSocketRequestMessage, timeout: Duration): Single<WebsocketResponse> {
return try {
delayedDisconnectThread?.resetLastInteractionTime()
restartDelayedDisconnectIfNecessary()
getWebSocket().sendRequest(request, timeout.inWholeSeconds)
} catch (e: IOException) {
Single.error(e)
@@ -194,6 +201,7 @@ sealed class SignalWebSocket(
}
if (connection == null || connection?.isDead() == true) {
connection?.shutdown()
disposable.dispose()
disposable = CompositeDisposable()
@@ -216,8 +224,22 @@ sealed class SignalWebSocket(
/**
* Schedules a disconnect [disconnectTimeout] from now, but only when the connection is alive and no
* keep alive tokens are held. Any previously scheduled disconnect is cancelled first.
*/
private fun startDelayedDisconnectIfNecessary() {
if (connection.isAlive() && keepAliveTokens.isEmpty()) {
// NOTE(review): both a thread-based and a job-based scheduling path appear below; this looks like
// old and new revisions of the same logic side by side — confirm which one is meant to be live.
delayedDisconnectThread?.abort()
delayedDisconnectThread = DelayedDisconnectThread().also { it.start() }
delayedDisconnectJob?.cancel()
delayedDisconnectJob = scope.launch {
Log.v(TAG, "$connectionName Disconnect scheduled in $disconnectTimeout")
delay(disconnectTimeout)
// Re-check after the wait: skip the disconnect if keep alives should be sent by now.
if (!shouldSendKeepAlives()) {
disconnect()
}
}
}
}
/**
 * Pushes a pending delayed disconnect back by scheduling it again from now. Does nothing when no
 * delayed disconnect job exists or the existing one is no longer active.
 */
private fun restartDelayedDisconnectIfNecessary() {
    synchronized(this) {
        // Only an active, already-scheduled disconnect gets rescheduled.
        val pendingDisconnect = delayedDisconnectJob ?: return
        if (pendingDisconnect.isActive) {
            startDelayedDisconnectIfNecessary()
        }
    }
}
@@ -227,50 +249,6 @@ sealed class SignalWebSocket(
disconnect()
}
/**
 * Allow the WebSocket to self destruct if there are no keep alive tokens and it's been longer
 * than [disconnectTimeout] since the last request was made.
 *
 * [abort] and [resetLastInteractionTime] are invoked from outside [run], so the state they write is
 * marked `@Volatile` to make those writes visible to the running thread.
 */
private inner class DelayedDisconnectThread : Thread() {
    // Written by abort() on the caller's thread and read by run() on this thread. Without @Volatile the
    // final check in run() could miss an abort that arrived after the wait loop had already finished.
    @Volatile
    private var abort = false

    @Volatile
    private var lastInteractionTime = Duration.ZERO

    /**
     * Cancels the scheduled disconnect. Has no effect if already aborted or if the thread is not running.
     */
    fun abort() {
        if (!abort && isAlive) {
            Log.v(TAG, "$connectionName Scheduled disconnect aborted.")
            abort = true
            // Wake the thread if it is currently sleeping so it can observe the abort flag.
            interrupt()
        }
    }

    /**
     * Records "now" as the last interaction, which pushes the disconnect deadline out by [disconnectTimeout].
     */
    fun resetLastInteractionTime() {
        lastInteractionTime = System.currentTimeMillis().milliseconds
    }

    /**
     * Sleeps until [disconnectTimeout] has elapsed since the last interaction, then disconnects unless
     * the thread was aborted or keep alives should be sent.
     */
    override fun run() {
        lastInteractionTime = System.currentTimeMillis().milliseconds
        try {
            while (!abort && (lastInteractionTime + disconnectTimeout) > System.currentTimeMillis().milliseconds) {
                val now = System.currentTimeMillis().milliseconds
                // Clamp a last interaction time that lies in the future (e.g. after the clock moved backwards).
                if (lastInteractionTime > now) {
                    lastInteractionTime = now
                }
                val sleepDuration = (lastInteractionTime + disconnectTimeout) - now
                if (sleepDuration.isPositive()) {
                    Log.v(TAG, "$connectionName Disconnect scheduled in $sleepDuration")
                    sleepTimer.sleep(sleepDuration.inWholeMilliseconds)
                }
            }
        } catch (_: InterruptedException) { }
        if (!abort && !shouldSendKeepAlives()) {
            disconnect()
        }
    }
}
/**
 * Whether this connection exists and does not report itself as dead. A `null` connection is never alive.
 */
private fun WebSocketConnection?.isAlive(): Boolean {
    // A missing connection counts as dead.
    val dead = this?.isDead() ?: true
    return !dead
}
@@ -369,6 +369,11 @@ class LibSignalChatConnection(
}
}
/**
* Disconnects this connection and then shuts down its [listener].
*/
override fun shutdown() {
disconnect()
// The listener is shut down only after the disconnect has been requested.
listener.shutdown()
}
override fun sendRequest(request: WebSocketRequestMessage, timeoutSeconds: Long): Single<WebsocketResponse> {
CHAT_SERVICE_LOCK.withLock {
if (isDead()) {
@@ -605,6 +610,10 @@ class LibSignalChatConnection(
private inner class LibSignalChatListener : ChatConnectionListener {
private val executor = Executors.newSingleThreadExecutor()
/**
 * Shuts down the single-thread executor owned by this listener.
 */
fun shutdown(): Unit = executor.shutdown()
override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) {
// NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is
// already in the ackSender map, if it exists.
@@ -28,6 +28,13 @@ interface WebSocketConnection {
fun disconnect()
/**
 * Terminal counterpart to [disconnect]: once this has been called the connection must not be reused.
 * The default implementation simply calls [disconnect].
 */
fun shutdown(): Unit = disconnect()
@Throws(IOException::class)
fun sendRequest(request: WebSocketRequestMessage): Single<WebsocketResponse> {
return sendRequest(request, DEFAULT_SEND_TIMEOUT.inWholeSeconds)