diff --git a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java index 696410e3ca..d94086961a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java +++ b/app/src/main/java/org/thoughtcrime/securesms/dependencies/ApplicationDependencyProvider.java @@ -342,7 +342,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { @Override public @NonNull SignalWebSocket.AuthenticatedWebSocket provideAuthWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier) { SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer(); - SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer); + SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer, true); WebSocketFactory authFactory = () -> { DynamicCredentialsProvider credentialsProvider = new DynamicCredentialsProvider(); @@ -375,7 +375,7 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider { @Override public @NonNull SignalWebSocket.UnauthenticatedWebSocket provideUnauthWebSocket(@NonNull Supplier signalServiceConfigurationSupplier, @NonNull Supplier libSignalNetworkSupplier) { SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer(); - SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer); + SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer, false); WebSocketFactory unauthFactory = () -> { Network network = libSignalNetworkSupplier.get(); diff --git a/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.kt b/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.kt index 4ca847964b..913992870a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/net/SignalWebSocketHealthMonitor.kt @@ -10,13 +10,13 @@ import io.reactivex.rxjava3.schedulers.Schedulers import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.delay import kotlinx.coroutines.launch import org.signal.core.util.logging.Log import org.thoughtcrime.securesms.dependencies.AppDependencies import org.thoughtcrime.securesms.keyvalue.SignalStore -import org.thoughtcrime.securesms.net.SignalWebSocketHealthMonitor.Companion.KEEP_ALIVE_SEND_CADENCE -import org.thoughtcrime.securesms.net.SignalWebSocketHealthMonitor.Companion.KEEP_ALIVE_TIMEOUT +import org.thoughtcrime.securesms.util.AppForegroundObserver import org.thoughtcrime.securesms.util.TextSecurePreferences import org.whispersystems.signalservice.api.util.SleepTimer import org.whispersystems.signalservice.api.websocket.HealthMonitor @@ -30,22 +30,15 @@ import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds class SignalWebSocketHealthMonitor( - private val sleepTimer: SleepTimer + private val sleepTimer: SleepTimer, + private val sendKeepAlives: Boolean = true ) : HealthMonitor { companion object { private val TAG = Log.tag(SignalWebSocketHealthMonitor::class) - /** - * This is the amount of time in between sent keep alives. Must be greater than [KEEP_ALIVE_TIMEOUT] - */ private val KEEP_ALIVE_SEND_CADENCE: Duration = OkHttpWebSocketConnection.KEEPALIVE_FREQUENCY_SECONDS.seconds - - /** - * This is the amount of time we will wait for a response to the keep alive before we consider the websockets dead. - * It is required that this value be less than [KEEP_ALIVE_SEND_CADENCE] - */ - private val KEEP_ALIVE_TIMEOUT: Duration = 20.seconds + private val KEEP_ALIVE_SEND_CADENCE_BACKGROUND: Duration = 60.seconds } private val executor: Executor = Executors.newSingleThreadExecutor() @@ -56,7 +49,7 @@ class SignalWebSocketHealthMonitor( private var needsKeepAlive = false private var lastKeepAliveReceived: Duration = 0.seconds - private val scope = CoroutineScope(Dispatchers.IO) + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) private var connectingTimeoutJob: Job? = null private var failedInConnecting: Boolean = false @@ -74,7 +67,9 @@ class SignalWebSocketHealthMonitor( .distinctUntilChanged() .subscribeBy { onStateChanged(it) } - webSocket.addKeepAliveChangeListener { executor.execute(this::updateKeepAliveSenderStatus) } + if (sendKeepAlives) { + webSocket.addKeepAliveChangeListener { executor.execute(this::updateKeepAliveSenderStatus) } + } } } @@ -111,7 +106,7 @@ class SignalWebSocketHealthMonitor( else -> Unit } - needsKeepAlive = connectionState == WebSocketConnectionState.CONNECTED + needsKeepAlive = connectionState == WebSocketConnectionState.CONNECTED && sendKeepAlives if (connectionState != WebSocketConnectionState.CONNECTING) { connectingTimeoutJob?.let { @@ -163,7 +158,7 @@ class SignalWebSocketHealthMonitor( /** * Sends periodic heartbeats/keep-alives over the WebSocket to prevent connection timeouts. If - * the WebSocket fails to get a return heartbeat after [KEEP_ALIVE_TIMEOUT] seconds, it is forced to be recreated. + * the WebSocket fails to get a return heartbeat before the next keep alive is sent, it is forced to be recreated. */ private inner class KeepAliveSender : Thread() { @@ -178,11 +173,12 @@ class SignalWebSocketHealthMonitor( var hasSentKeepAlive = false while (shouldKeepRunning && sendKeepAlives()) { try { - sleepUntil(keepAliveSentTime + KEEP_ALIVE_SEND_CADENCE) + val cadence = if (AppForegroundObserver.isForegrounded()) KEEP_ALIVE_SEND_CADENCE else KEEP_ALIVE_SEND_CADENCE_BACKGROUND + sleepUntil(keepAliveSentTime + cadence) if (shouldKeepRunning && sendKeepAlives()) { if (hasSentKeepAlive && lastKeepAliveReceived < keepAliveSentTime) { - Log.w(TAG, "Missed keep alive, last: ${lastKeepAliveReceived.inWholeMilliseconds} needed by: ${(keepAliveSentTime + KEEP_ALIVE_TIMEOUT).inWholeMilliseconds}") + Log.w(TAG, "Missed keep alive, last: ${lastKeepAliveReceived.inWholeMilliseconds} needed by: ${keepAliveSentTime.inWholeMilliseconds}") webSocket?.forceNewWebSocket() } @@ -190,6 +186,8 @@ class SignalWebSocketHealthMonitor( webSocket?.sendKeepAlive() hasSentKeepAlive = true } + } catch (e: InterruptedException) { + // Stopped } catch (e: Throwable) { Log.w(TAG, e) } @@ -198,7 +196,7 @@ class SignalWebSocketHealthMonitor( } fun sleepUntil(time: Duration) { - while (System.currentTimeMillis().milliseconds < time) { + while (shouldKeepRunning && System.currentTimeMillis().milliseconds < time) { val waitTime = time - System.currentTimeMillis().milliseconds if (waitTime.isPositive()) { try { @@ -212,6 +210,7 @@ class SignalWebSocketHealthMonitor( fun shutdown() { shouldKeepRunning = false + interrupt() } } } diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt index d88b9ce2dd..5d1d1bd57c 100644 --- a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/api/websocket/SignalWebSocket.kt @@ -12,6 +12,12 @@ import io.reactivex.rxjava3.kotlin.addTo import io.reactivex.rxjava3.kotlin.subscribeBy import io.reactivex.rxjava3.schedulers.Schedulers import io.reactivex.rxjava3.subjects.BehaviorSubject +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch import org.signal.core.util.logging.Log import org.signal.core.util.orNull import org.signal.libsignal.internal.CompletableFuture @@ -65,7 +71,8 @@ sealed class SignalWebSocket( private val keepAliveTokens: MutableSet = CopyOnWriteArraySet() private val keepAliveChangeListeners: MutableSet = CopyOnWriteArraySet() - private var delayedDisconnectThread: DelayedDisconnectThread? = null + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + private var delayedDisconnectJob: Job? = null val state: Observable = _state val stateSnapshot: WebSocketConnectionState @@ -88,7 +95,7 @@ sealed class SignalWebSocket( if (connection != null) { disposable.dispose() - connection!!.disconnect() + connection!!.shutdown() connection = null if (!_state.value!!.isFailure) { @@ -117,8 +124,8 @@ sealed class SignalWebSocket( } synchronized(this) { - delayedDisconnectThread?.abort() - delayedDisconnectThread = null + delayedDisconnectJob?.cancel() + delayedDisconnectJob = null if (canConnect.canConnect()) { try { @@ -153,7 +160,7 @@ sealed class SignalWebSocket( fun request(request: WebSocketRequestMessage): Single { return try { - delayedDisconnectThread?.resetLastInteractionTime() + restartDelayedDisconnectIfNecessary() getWebSocket().sendRequest(request) } catch (e: IOException) { Single.error(e) @@ -162,7 +169,7 @@ sealed class SignalWebSocket( fun request(request: WebSocketRequestMessage, timeout: Duration): Single { return try { - delayedDisconnectThread?.resetLastInteractionTime() + restartDelayedDisconnectIfNecessary() getWebSocket().sendRequest(request, timeout.inWholeSeconds) } catch (e: IOException) { Single.error(e) @@ -194,6 +201,7 @@ sealed class SignalWebSocket( } if (connection == null || connection?.isDead() == true) { + connection?.shutdown() disposable.dispose() disposable = CompositeDisposable() @@ -216,8 +224,22 @@ sealed class SignalWebSocket( private fun startDelayedDisconnectIfNecessary() { if (connection.isAlive() && keepAliveTokens.isEmpty()) { - delayedDisconnectThread?.abort() - delayedDisconnectThread = DelayedDisconnectThread().also { it.start() } + delayedDisconnectJob?.cancel() + delayedDisconnectJob = scope.launch { + Log.v(TAG, "$connectionName Disconnect scheduled in $disconnectTimeout") + delay(disconnectTimeout) + if (!shouldSendKeepAlives()) { + disconnect() + } + } + } + } + + private fun restartDelayedDisconnectIfNecessary() { + synchronized(this) { + if (delayedDisconnectJob?.isActive == true) { + startDelayedDisconnectIfNecessary() + } } } @@ -227,50 +249,6 @@ sealed class SignalWebSocket( disconnect() } - /** - * Allow the WebSocket to self destruct if there are no keep alive tokens and it's been longer - * than [disconnectTimeout] since the last request was made. - */ - private inner class DelayedDisconnectThread : Thread() { - private var abort = false - - @Volatile - private var lastInteractionTime = Duration.ZERO - - fun abort() { - if (!abort && isAlive) { - Log.v(TAG, "$connectionName Scheduled disconnect aborted.") - abort = true - interrupt() - } - } - - fun resetLastInteractionTime() { - lastInteractionTime = System.currentTimeMillis().milliseconds - } - - override fun run() { - lastInteractionTime = System.currentTimeMillis().milliseconds - try { - while (!abort && (lastInteractionTime + disconnectTimeout) > System.currentTimeMillis().milliseconds) { - val now = System.currentTimeMillis().milliseconds - if (lastInteractionTime > now) { - lastInteractionTime = now - } - val sleepDuration = (lastInteractionTime + disconnectTimeout) - now - if (sleepDuration.isPositive()) { - Log.v(TAG, "$connectionName Disconnect scheduled in $sleepDuration") - sleepTimer.sleep(sleepDuration.inWholeMilliseconds) - } - } - } catch (_: InterruptedException) { } - - if (!abort && !shouldSendKeepAlives()) { - disconnect() - } - } - } - private fun WebSocketConnection?.isAlive(): Boolean { return this?.isDead() == false } diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt index 7bfa104053..54fdcb6fe5 100644 --- a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/LibSignalChatConnection.kt @@ -369,6 +369,11 @@ class LibSignalChatConnection( } } + override fun shutdown() { + disconnect() + listener.shutdown() + } + override fun sendRequest(request: WebSocketRequestMessage, timeoutSeconds: Long): Single { CHAT_SERVICE_LOCK.withLock { if (isDead()) { @@ -605,6 +610,10 @@ class LibSignalChatConnection( private inner class LibSignalChatListener : ChatConnectionListener { private val executor = Executors.newSingleThreadExecutor() + fun shutdown() { + executor.shutdown() + } + override fun onIncomingMessage(chat: ChatConnection, envelope: ByteArray, serverDeliveryTimestamp: Long, sendAck: ChatConnectionListener.ServerMessageAck?) { // NB: The order here is intentional to ensure concurrency-safety, so that when a request is pulled off the queue, its sendAck is // already in the ackSender map, if it exists. diff --git a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt index 5977eceaf3..79aae9cae4 100644 --- a/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt +++ b/lib/libsignal-service/src/main/java/org/whispersystems/signalservice/internal/websocket/WebSocketConnection.kt @@ -28,6 +28,13 @@ interface WebSocketConnection { fun disconnect() + /** + * Unlike [disconnect], this connection should not be reused after calling this method. + */ + fun shutdown() { + disconnect() + } + @Throws(IOException::class) fun sendRequest(request: WebSocketRequestMessage): Single { return sendRequest(request, DEFAULT_SEND_TIMEOUT.inWholeSeconds)