Improve auth WebSocket lifecycle.

This commit is contained in:
Cody Henthorne
2025-03-18 13:38:21 -04:00
committed by Alex Hart
parent 6bbd899507
commit 323697dfc9
16 changed files with 300 additions and 205 deletions

View File

@@ -110,6 +110,7 @@ import org.thoughtcrime.securesms.util.TextSecurePreferences;
import org.thoughtcrime.securesms.util.Util; import org.thoughtcrime.securesms.util.Util;
import org.thoughtcrime.securesms.util.VersionTracker; import org.thoughtcrime.securesms.util.VersionTracker;
import org.thoughtcrime.securesms.util.dynamiclanguage.DynamicLanguageContextWrapper; import org.thoughtcrime.securesms.util.dynamiclanguage.DynamicLanguageContextWrapper;
import org.whispersystems.signalservice.api.websocket.SignalWebSocket;
import java.io.InterruptedIOException; import java.io.InterruptedIOException;
import java.net.SocketException; import java.net.SocketException;
@@ -259,7 +260,8 @@ public class ApplicationContext extends Application implements AppForegroundObse
checkFreeDiskSpace(); checkFreeDiskSpace();
MemoryTracker.start(); MemoryTracker.start();
BackupSubscriptionCheckJob.enqueueIfAble(); BackupSubscriptionCheckJob.enqueueIfAble();
AppDependencies.getUnauthWebSocket().setShouldSendKeepAlives(true); AppDependencies.getAuthWebSocket().registerKeepAliveToken(SignalWebSocket.FOREGROUND_KEEPALIVE);
AppDependencies.getUnauthWebSocket().registerKeepAliveToken(SignalWebSocket.FOREGROUND_KEEPALIVE);
long lastForegroundTime = SignalStore.misc().getLastForegroundTime(); long lastForegroundTime = SignalStore.misc().getLastForegroundTime();
long currentTime = System.currentTimeMillis(); long currentTime = System.currentTimeMillis();
@@ -283,7 +285,8 @@ public class ApplicationContext extends Application implements AppForegroundObse
AppDependencies.getFrameRateTracker().stop(); AppDependencies.getFrameRateTracker().stop();
AppDependencies.getShakeToReport().disable(); AppDependencies.getShakeToReport().disable();
AppDependencies.getDeadlockDetector().stop(); AppDependencies.getDeadlockDetector().stop();
AppDependencies.getUnauthWebSocket().setShouldSendKeepAlives(false); AppDependencies.getAuthWebSocket().removeKeepAliveToken(SignalWebSocket.FOREGROUND_KEEPALIVE);
AppDependencies.getUnauthWebSocket().removeKeepAliveToken(SignalWebSocket.FOREGROUND_KEEPALIVE);
MemoryTracker.stop(); MemoryTracker.stop();
AnrDetector.stop(); AnrDetector.stop();
} }

View File

@@ -67,6 +67,7 @@ import org.thoughtcrime.securesms.service.webrtc.SignalCallManager;
import org.thoughtcrime.securesms.shakereport.ShakeToReport; import org.thoughtcrime.securesms.shakereport.ShakeToReport;
import org.thoughtcrime.securesms.stories.Stories; import org.thoughtcrime.securesms.stories.Stories;
import org.thoughtcrime.securesms.util.AlarmSleepTimer; import org.thoughtcrime.securesms.util.AlarmSleepTimer;
import org.thoughtcrime.securesms.util.AppForegroundObserver;
import org.thoughtcrime.securesms.util.ByteUnit; import org.thoughtcrime.securesms.util.ByteUnit;
import org.thoughtcrime.securesms.util.EarlyMessageCache; import org.thoughtcrime.securesms.util.EarlyMessageCache;
import org.thoughtcrime.securesms.util.Environment; import org.thoughtcrime.securesms.util.Environment;
@@ -102,15 +103,14 @@ import org.whispersystems.signalservice.api.username.UsernameApi;
import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.CredentialsProvider;
import org.whispersystems.signalservice.api.util.SleepTimer; import org.whispersystems.signalservice.api.util.SleepTimer;
import org.whispersystems.signalservice.api.util.UptimeSleepTimer; import org.whispersystems.signalservice.api.util.UptimeSleepTimer;
import org.whispersystems.signalservice.api.websocket.HealthMonitor;
import org.whispersystems.signalservice.api.websocket.SignalWebSocket; import org.whispersystems.signalservice.api.websocket.SignalWebSocket;
import org.whispersystems.signalservice.api.websocket.WebSocketFactory; import org.whispersystems.signalservice.api.websocket.WebSocketFactory;
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException;
import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration; import org.whispersystems.signalservice.internal.configuration.SignalServiceConfiguration;
import org.whispersystems.signalservice.internal.push.PushServiceSocket; import org.whispersystems.signalservice.internal.push.PushServiceSocket;
import org.whispersystems.signalservice.internal.websocket.LibSignalChatConnection; import org.whispersystems.signalservice.internal.websocket.LibSignalChatConnection;
import org.whispersystems.signalservice.internal.websocket.LibSignalNetworkExtensions; import org.whispersystems.signalservice.internal.websocket.LibSignalNetworkExtensions;
import org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection; import org.whispersystems.signalservice.internal.websocket.OkHttpWebSocketConnection;
import org.whispersystems.signalservice.internal.websocket.WebSocketConnection;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@@ -303,10 +303,37 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
@Override @Override
public @NonNull SignalWebSocket.AuthenticatedWebSocket provideAuthWebSocket(@NonNull Supplier<SignalServiceConfiguration> signalServiceConfigurationSupplier, @NonNull Supplier<Network> libSignalNetworkSupplier) { public @NonNull SignalWebSocket.AuthenticatedWebSocket provideAuthWebSocket(@NonNull Supplier<SignalServiceConfiguration> signalServiceConfigurationSupplier, @NonNull Supplier<Network> libSignalNetworkSupplier) {
SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer(); SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer();
SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer); SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer);
WebSocketFactory webSocketFactory = provideWebSocketFactory(signalServiceConfigurationSupplier, healthMonitor, libSignalNetworkSupplier);
SignalWebSocket.AuthenticatedWebSocket webSocket = new SignalWebSocket.AuthenticatedWebSocket(webSocketFactory::createWebSocket); WebSocketFactory authFactory = () -> {
DynamicCredentialsProvider credentialsProvider = new DynamicCredentialsProvider();
if (credentialsProvider.isInvalid()) {
throw new WebSocketUnavailableException("Invalid auth credentials");
}
if (RemoteConfig.libSignalWebSocketEnabled()) {
Network network = libSignalNetworkSupplier.get();
return new LibSignalChatConnection("libsignal-auth",
network,
credentialsProvider,
Stories.isFeatureEnabled(),
healthMonitor);
} else {
return new OkHttpWebSocketConnection("auth",
signalServiceConfigurationSupplier.get(),
Optional.of(credentialsProvider),
BuildConfig.SIGNAL_AGENT,
healthMonitor,
Stories.isFeatureEnabled());
}
};
SignalWebSocket.AuthenticatedWebSocket webSocket = new SignalWebSocket.AuthenticatedWebSocket(authFactory, sleepTimer, TimeUnit.SECONDS.toMillis(10));
if (AppForegroundObserver.isForegrounded()) {
webSocket.registerKeepAliveToken(SignalWebSocket.FOREGROUND_KEEPALIVE);
}
healthMonitor.monitor(webSocket); healthMonitor.monitor(webSocket);
@@ -315,13 +342,33 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
@Override @Override
public @NonNull SignalWebSocket.UnauthenticatedWebSocket provideUnauthWebSocket(@NonNull Supplier<SignalServiceConfiguration> signalServiceConfigurationSupplier, @NonNull Supplier<Network> libSignalNetworkSupplier) { public @NonNull SignalWebSocket.UnauthenticatedWebSocket provideUnauthWebSocket(@NonNull Supplier<SignalServiceConfiguration> signalServiceConfigurationSupplier, @NonNull Supplier<Network> libSignalNetworkSupplier) {
SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer(); SleepTimer sleepTimer = !SignalStore.account().isFcmEnabled() || SignalStore.internal().isWebsocketModeForced() ? new AlarmSleepTimer(context) : new UptimeSleepTimer();
SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer); SignalWebSocketHealthMonitor healthMonitor = new SignalWebSocketHealthMonitor(sleepTimer);
WebSocketFactory webSocketFactory = provideWebSocketFactory(signalServiceConfigurationSupplier, healthMonitor, libSignalNetworkSupplier);
SignalWebSocket.UnauthenticatedWebSocket webSocket = new SignalWebSocket.UnauthenticatedWebSocket(webSocketFactory::createUnidentifiedWebSocket); WebSocketFactory unauthFactory = () -> {
if (RemoteConfig.libSignalWebSocketEnabled()) {
Network network = libSignalNetworkSupplier.get();
return new LibSignalChatConnection("libsignal-unauth",
network,
null,
Stories.isFeatureEnabled(),
healthMonitor);
} else {
return new OkHttpWebSocketConnection("unauth",
signalServiceConfigurationSupplier.get(),
Optional.empty(),
BuildConfig.SIGNAL_AGENT,
healthMonitor,
Stories.isFeatureEnabled());
}
};
SignalWebSocket.UnauthenticatedWebSocket webSocket = new SignalWebSocket.UnauthenticatedWebSocket(unauthFactory, sleepTimer, TimeUnit.SECONDS.toMillis(10));
if (AppForegroundObserver.isForegrounded()) {
webSocket.registerKeepAliveToken(SignalWebSocket.FOREGROUND_KEEPALIVE);
}
healthMonitor.monitor(webSocket); healthMonitor.monitor(webSocket);
return webSocket; return webSocket;
} }
@@ -413,51 +460,6 @@ public class ApplicationDependencyProvider implements AppDependencies.Provider {
return provideClientZkOperations(signalServiceConfiguration).getReceiptOperations(); return provideClientZkOperations(signalServiceConfiguration).getReceiptOperations();
} }
@NonNull WebSocketFactory provideWebSocketFactory(@NonNull Supplier<SignalServiceConfiguration> signalServiceConfigurationSupplier,
@NonNull HealthMonitor healthMonitor,
@NonNull Supplier<Network> libSignalNetworkSupplier)
{
return new WebSocketFactory() {
@Override
public WebSocketConnection createWebSocket() {
if (RemoteConfig.libSignalWebSocketEnabled()) {
Network network = libSignalNetworkSupplier.get();
return new LibSignalChatConnection("libsignal-auth",
network,
new DynamicCredentialsProvider(),
Stories.isFeatureEnabled(),
healthMonitor);
} else {
return new OkHttpWebSocketConnection("normal",
signalServiceConfigurationSupplier.get(),
Optional.of(new DynamicCredentialsProvider()),
BuildConfig.SIGNAL_AGENT,
healthMonitor,
Stories.isFeatureEnabled());
}
}
@Override
public WebSocketConnection createUnidentifiedWebSocket() {
if (RemoteConfig.libSignalWebSocketEnabled()) {
Network network = libSignalNetworkSupplier.get();
return new LibSignalChatConnection("libsignal-unauth",
network,
null,
Stories.isFeatureEnabled(),
healthMonitor);
} else {
return new OkHttpWebSocketConnection("unidentified",
signalServiceConfigurationSupplier.get(),
Optional.empty(),
BuildConfig.SIGNAL_AGENT,
healthMonitor,
Stories.isFeatureEnabled());
}
}
};
}
@Override @Override
public @NonNull BillingApi provideBillingApi() { public @NonNull BillingApi provideBillingApi() {
return BillingFactory.create(GooglePlayBillingDependencies.INSTANCE, RemoteConfig.messageBackups() && Environment.Backups.supportsGooglePlayBilling()); return BillingFactory.create(GooglePlayBillingDependencies.INSTANCE, RemoteConfig.messageBackups() && Environment.Backups.supportsGooglePlayBilling());

View File

@@ -11,6 +11,7 @@ import io.reactivex.rxjava3.kotlin.plusAssign
import io.reactivex.rxjava3.subjects.Subject import io.reactivex.rxjava3.subjects.Subject
import okhttp3.ConnectionSpec import okhttp3.ConnectionSpec
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import org.signal.core.util.logging.Log
import org.signal.core.util.resettableLazy import org.signal.core.util.resettableLazy
import org.signal.libsignal.net.Network import org.signal.libsignal.net.Network
import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations
@@ -46,6 +47,7 @@ import org.whispersystems.signalservice.api.username.UsernameApi
import org.whispersystems.signalservice.api.util.Tls12SocketFactory import org.whispersystems.signalservice.api.util.Tls12SocketFactory
import org.whispersystems.signalservice.api.websocket.SignalWebSocket import org.whispersystems.signalservice.api.websocket.SignalWebSocket
import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException
import org.whispersystems.signalservice.internal.push.PushServiceSocket import org.whispersystems.signalservice.internal.push.PushServiceSocket
import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager
import org.whispersystems.signalservice.internal.util.Util import org.whispersystems.signalservice.internal.util.Util
@@ -64,6 +66,10 @@ class NetworkDependenciesModule(
private val webSocketStateSubject: Subject<WebSocketConnectionState> private val webSocketStateSubject: Subject<WebSocketConnectionState>
) { ) {
companion object {
private val TAG = "NetworkDependencies"
}
private val disposables: CompositeDisposable = CompositeDisposable() private val disposables: CompositeDisposable = CompositeDisposable()
val signalServiceNetworkAccess: SignalServiceNetworkAccess by lazy { val signalServiceNetworkAccess: SignalServiceNetworkAccess by lazy {
@@ -215,6 +221,7 @@ class NetworkDependenciesModule(
} }
fun closeConnections() { fun closeConnections() {
Log.i(TAG, "Closing connections.")
incomingMessageObserver.terminateAsync() incomingMessageObserver.terminateAsync()
if (_signalServiceMessageSender.isInitialized()) { if (_signalServiceMessageSender.isInitialized()) {
signalServiceMessageSender.cancelInFlightRequests() signalServiceMessageSender.cancelInFlightRequests()
@@ -224,8 +231,19 @@ class NetworkDependenciesModule(
} }
fun openConnections() { fun openConnections() {
try {
authWebSocket.connect()
} catch (e: WebSocketUnavailableException) {
Log.w(TAG, "Not allowed to start auth websocket", e)
}
try {
unauthWebSocket.connect()
} catch (e: WebSocketUnavailableException) {
Log.w(TAG, "Not allowed to start unauth websocket", e)
}
incomingMessageObserver incomingMessageObserver
unauthWebSocket.connect()
} }
fun resetProtocolStores() { fun resetProtocolStores() {

View File

@@ -45,7 +45,14 @@ object FcmFetchManager {
private val KEEP_ALIVE_TOKEN = "FcmFetch" private val KEEP_ALIVE_TOKEN = "FcmFetch"
val WEBSOCKET_DRAIN_TIMEOUT = 5.minutes.inWholeMilliseconds val WEBSOCKET_DRAIN_TIMEOUT: Long
get() {
return if (AppDependencies.signalServiceNetworkAccess.isCensored()) {
2.minutes.inWholeMilliseconds
} else {
5.minutes.inWholeMilliseconds
}
}
@Volatile @Volatile
private var activeCount = 0 private var activeCount = 0

View File

@@ -7,7 +7,9 @@ import android.content.Intent
import android.os.IBinder import android.os.IBinder
import androidx.annotation.VisibleForTesting import androidx.annotation.VisibleForTesting
import androidx.core.app.NotificationCompat import androidx.core.app.NotificationCompat
import kotlinx.collections.immutable.toImmutableSet import io.reactivex.rxjava3.disposables.Disposable
import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.schedulers.Schedulers
import org.signal.core.util.concurrent.SignalExecutors import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.R import org.thoughtcrime.securesms.R
@@ -52,23 +54,22 @@ import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds import kotlin.time.Duration.Companion.seconds
/** /**
* The application-level manager of our websocket connection. * The application-level manager of our incoming message processing.
* *
* This class is responsible for opening/closing the websocket based on the app's state and observing new inbound messages received on the websocket. * This class is responsible for keeping the authenticated websocket open based on the app's state for incoming messages and
* observing new inbound messages received over the websocket.
*/ */
class IncomingMessageObserver(private val context: Application, private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket) { class IncomingMessageObserver(private val context: Application, private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket) {
companion object { companion object {
private val TAG = Log.tag(IncomingMessageObserver::class.java) private val TAG = Log.tag(IncomingMessageObserver::class.java)
private const val WEB_SOCKET_KEEP_ALIVE_TOKEN = "MessageRetrieval"
/** How long we wait for the websocket to time out before we try to connect again. */ /** How long we wait for the websocket to time out before we try to connect again. */
private val websocketReadTimeout: Long private val websocketReadTimeout: Long
get() = if (censored) 30.seconds.inWholeMilliseconds else 1.minutes.inWholeMilliseconds get() = if (censored) 30.seconds.inWholeMilliseconds else 1.minutes.inWholeMilliseconds
/** How long a keep-alive token is allowed to keep the websocket open for. These are usually used for calling + FCM messages. */
private val keepAliveTokenMaxAge: Long
get() = if (censored) 2.minutes.inWholeMilliseconds else 5.minutes.inWholeMilliseconds
/** How long the websocket is allowed to keep running after the user backgrounds the app. Higher numbers allow us to rely on FCM less. */ /** How long the websocket is allowed to keep running after the user backgrounds the app. Higher numbers allow us to rely on FCM less. */
private val maxBackgroundTime: Long private val maxBackgroundTime: Long
get() = if (censored) 10.seconds.inWholeMilliseconds else 2.minutes.inWholeMilliseconds get() = if (censored) 10.seconds.inWholeMilliseconds else 2.minutes.inWholeMilliseconds
@@ -82,8 +83,6 @@ class IncomingMessageObserver(private val context: Application, private val auth
} }
private val decryptionDrainedListeners: MutableList<Runnable> = CopyOnWriteArrayList() private val decryptionDrainedListeners: MutableList<Runnable> = CopyOnWriteArrayList()
private val keepAliveTokens: MutableMap<String, Long> = mutableMapOf()
private val keepAlivePurgeCallbacks: MutableMap<String, MutableList<Runnable>> = mutableMapOf()
private val lock: ReentrantLock = ReentrantLock() private val lock: ReentrantLock = ReentrantLock()
private val connectionNecessarySemaphore = Semaphore(0) private val connectionNecessarySemaphore = Semaphore(0)
@@ -91,9 +90,8 @@ class IncomingMessageObserver(private val context: Application, private val auth
lock.withLock { lock.withLock {
AppDependencies.libsignalNetwork.onNetworkChange() AppDependencies.libsignalNetwork.onNetworkChange()
if (isNetworkUnavailable()) { if (isNetworkUnavailable()) {
Log.w(TAG, "Lost network connection. Shutting down our websocket connections and resetting the drained state.") Log.w(TAG, "Lost network connection. Resetting the drained state.")
decryptionDrained = false decryptionDrained = false
disconnect()
} }
connectionNecessarySemaphore.release() connectionNecessarySemaphore.release()
} }
@@ -103,6 +101,7 @@ class IncomingMessageObserver(private val context: Application, private val auth
private var appVisible = false private var appVisible = false
private var lastInteractionTime: Long = System.currentTimeMillis() private var lastInteractionTime: Long = System.currentTimeMillis()
private var webSocketStateDisposable = Disposable.disposed()
@Volatile @Volatile
private var terminated = false private var terminated = false
@@ -144,6 +143,17 @@ class IncomingMessageObserver(private val context: Application, private val auth
}) })
networkConnectionListener.register() networkConnectionListener.register()
webSocketStateDisposable = authWebSocket
.state
.observeOn(Schedulers.computation())
.subscribeBy {
if (it == WebSocketConnectionState.CONNECTED) {
lock.withLock {
connectionNecessarySemaphore.release()
}
}
}
} }
fun notifyRegistrationStateChanged() { fun notifyRegistrationStateChanged() {
@@ -179,23 +189,11 @@ class IncomingMessageObserver(private val context: Application, private val auth
private fun isConnectionNecessary(): Boolean { private fun isConnectionNecessary(): Boolean {
val timeIdle: Long val timeIdle: Long
val keepAliveEntries: Set<Pair<String, Long>>
val appVisibleSnapshot: Boolean val appVisibleSnapshot: Boolean
lock.withLock { lock.withLock {
appVisibleSnapshot = appVisible appVisibleSnapshot = appVisible
timeIdle = if (appVisibleSnapshot) 0 else System.currentTimeMillis() - lastInteractionTime timeIdle = if (appVisibleSnapshot) 0 else System.currentTimeMillis() - lastInteractionTime
val keepAliveCutoffTime = System.currentTimeMillis() - keepAliveTokenMaxAge
keepAliveEntries = keepAliveTokens.entries.mapNotNull { (key, createTime) ->
if (createTime < keepAliveCutoffTime) {
Log.d(TAG, "Removed old keep web socket keep alive token $key")
keepAlivePurgeCallbacks.remove(key)?.forEach { it.run() }
null
} else {
key to createTime
}
}.toImmutableSet()
} }
val registered = SignalStore.account.isRegistered val registered = SignalStore.account.isRegistered
@@ -203,24 +201,33 @@ class IncomingMessageObserver(private val context: Application, private val auth
val hasNetwork = NetworkConstraint.isMet(context) val hasNetwork = NetworkConstraint.isMet(context)
val hasProxy = SignalStore.proxy.isProxyEnabled val hasProxy = SignalStore.proxy.isProxyEnabled
val forceWebsocket = SignalStore.internal.isWebsocketModeForced val forceWebsocket = SignalStore.internal.isWebsocketModeForced
val isRestoreDecisionPending = RemoteConfig.restoreAfterRegistration && SignalStore.registration.restoreDecisionState.isDecisionPending val websocketAlreadyOpen = isConnectionAvailable()
val canProcessIncomingMessages = canProcessIncomingMessages()
val lastInteractionString = if (appVisibleSnapshot) "N/A" else timeIdle.toString() + " ms (" + (if (timeIdle < maxBackgroundTime) "within limit" else "over limit") + ")" val lastInteractionString = if (appVisibleSnapshot) "N/A" else timeIdle.toString() + " ms (" + (if (timeIdle < maxBackgroundTime) "within limit" else "over limit") + ")"
val conclusion = registered && val conclusion = registered &&
(appVisibleSnapshot || timeIdle < maxBackgroundTime || !fcmEnabled || keepAliveEntries.isNotEmpty()) && (appVisibleSnapshot || timeIdle < maxBackgroundTime || !fcmEnabled) &&
hasNetwork && hasNetwork &&
!isRestoreDecisionPending canProcessIncomingMessages
val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection" val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection"
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: $keepAliveEntries, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Pending restore: $isRestoreDecisionPending") Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, WS Connected: $websocketAlreadyOpen, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Can process messages: $canProcessIncomingMessages")
return conclusion return conclusion
} }
private fun isConnectionAvailable(): Boolean {
return authWebSocket.stateSnapshot == WebSocketConnectionState.CONNECTED
}
private fun canProcessIncomingMessages(): Boolean {
return !(RemoteConfig.restoreAfterRegistration && SignalStore.registration.restoreDecisionState.isDecisionPending)
}
private fun waitForConnectionNecessary() { private fun waitForConnectionNecessary() {
try { try {
connectionNecessarySemaphore.drainPermits() connectionNecessarySemaphore.drainPermits()
while (!isConnectionNecessary()) { while (!isConnectionNecessary() && !(isConnectionAvailable() && canProcessIncomingMessages())) {
val numberDrained = connectionNecessarySemaphore.drainPermits() val numberDrained = connectionNecessarySemaphore.drainPermits()
if (numberDrained == 0) { if (numberDrained == 0) {
connectionNecessarySemaphore.acquire() connectionNecessarySemaphore.acquire()
@@ -235,38 +242,10 @@ class IncomingMessageObserver(private val context: Application, private val auth
Log.w(TAG, "Termination Enqueued! ${this.hashCode()}", Throwable()) Log.w(TAG, "Termination Enqueued! ${this.hashCode()}", Throwable())
INSTANCE_COUNT.decrementAndGet() INSTANCE_COUNT.decrementAndGet()
networkConnectionListener.unregister() networkConnectionListener.unregister()
webSocketStateDisposable.dispose()
SignalExecutors.BOUNDED.execute { SignalExecutors.BOUNDED.execute {
Log.w(TAG, "Beginning termination. ${this.hashCode()}") Log.w(TAG, "Beginning termination. ${this.hashCode()}")
terminated = true terminated = true
disconnect()
}
}
private fun disconnect() {
authWebSocket.disconnect()
}
@JvmOverloads
fun registerKeepAliveToken(key: String, runnable: Runnable? = null) {
lock.withLock {
keepAliveTokens[key] = System.currentTimeMillis()
if (runnable != null) {
if (!keepAlivePurgeCallbacks.containsKey(key)) {
keepAlivePurgeCallbacks[key] = ArrayList()
}
keepAlivePurgeCallbacks[key]?.add(runnable)
}
lastInteractionTime = System.currentTimeMillis()
connectionNecessarySemaphore.release()
}
}
fun removeKeepAliveToken(key: String) {
lock.withLock {
keepAliveTokens.remove(key)
keepAlivePurgeCallbacks.remove(key)
lastInteractionTime = System.currentTimeMillis()
connectionNecessarySemaphore.release()
} }
} }
@@ -396,9 +375,16 @@ class IncomingMessageObserver(private val context: Application, private val auth
} }
} }
authWebSocket.connect()
try { try {
while (!terminated && isConnectionNecessary()) { authWebSocket.connect()
var isConnectionNecessary = false
while (!terminated && canProcessIncomingMessages() && (isConnectionNecessary().also { isConnectionNecessary = it } || isConnectionAvailable())) {
if (isConnectionNecessary) {
authWebSocket.registerKeepAliveToken(WEB_SOCKET_KEEP_ALIVE_TOKEN)
} else {
authWebSocket.removeKeepAliveToken(WEB_SOCKET_KEEP_ALIVE_TOKEN)
}
try { try {
Log.d(TAG, "Reading message...") Log.d(TAG, "Reading message...")
@@ -461,8 +447,6 @@ class IncomingMessageObserver(private val context: Application, private val auth
attempts++ attempts++
Log.w(TAG, e) Log.w(TAG, e)
} finally { } finally {
Log.w(TAG, "Shutting down pipe...")
disconnect()
webSocketDisposable.dispose() webSocketDisposable.dispose()
} }
Log.i(TAG, "Looping...") Log.i(TAG, "Looping...")

View File

@@ -51,7 +51,6 @@ object WebSocketDrainer {
var websocketDrainTimeout = requestedWebsocketDrainTimeoutMs var websocketDrainTimeout = requestedWebsocketDrainTimeoutMs
val context = AppDependencies.application val context = AppDependencies.application
val incomingMessageObserver = AppDependencies.incomingMessageObserver
val powerManager = ServiceUtil.getPowerManager(context) val powerManager = ServiceUtil.getPowerManager(context)
val doze = PowerManagerCompat.isDeviceIdleMode(powerManager) val doze = PowerManagerCompat.isDeviceIdleMode(powerManager)
@@ -70,7 +69,7 @@ object WebSocketDrainer {
val wakeLock = WakeLockUtil.acquire(AppDependencies.application, PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeout + QUEUE_TIMEOUT, wakeLockTag) val wakeLock = WakeLockUtil.acquire(AppDependencies.application, PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeout + QUEUE_TIMEOUT, wakeLockTag)
return try { return try {
drainAndProcess(websocketDrainTimeout, incomingMessageObserver, keepAliveToken) drainAndProcess(websocketDrainTimeout, keepAliveToken)
} finally { } finally {
WakeLockUtil.release(wakeLock, wakeLockTag) WakeLockUtil.release(wakeLock, wakeLockTag)
} }
@@ -83,7 +82,7 @@ object WebSocketDrainer {
* so that we know the queue has been drained. * so that we know the queue has been drained.
*/ */
@WorkerThread @WorkerThread
private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver, keepAliveToken: String): Boolean { private fun drainAndProcess(timeout: Long, keepAliveToken: String): Boolean {
val stopwatch = Stopwatch("websocket-strategy") val stopwatch = Stopwatch("websocket-strategy")
val jobManager = AppDependencies.jobManager val jobManager = AppDependencies.jobManager
@@ -94,7 +93,7 @@ object WebSocketDrainer {
queueListener queueListener
) )
val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout, keepAliveToken) val successfullyDrained = blockUntilWebsocketDrained(timeout, keepAliveToken)
if (!successfullyDrained) { if (!successfullyDrained) {
return false return false
} }
@@ -116,19 +115,17 @@ object WebSocketDrainer {
return true return true
} }
private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long, keepAliveToken: String): Boolean { private fun blockUntilWebsocketDrained(timeoutMs: Long, keepAliveToken: String): Boolean {
try { try {
val latch = CountDownLatch(1) val latch = CountDownLatch(1)
var success = false var success = false
incomingMessageObserver.registerKeepAliveToken(keepAliveToken) { AppDependencies.authWebSocket.registerKeepAliveToken(keepAliveToken)
Log.w(TAG, "Keep alive token purged")
latch.countDown() AppDependencies.incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
}
incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
override fun run() { override fun run() {
success = true success = true
latch.countDown() latch.countDown()
incomingMessageObserver.removeDecryptionDrainedListener(this) AppDependencies.incomingMessageObserver.removeDecryptionDrainedListener(this)
} }
}) })
@@ -142,7 +139,7 @@ object WebSocketDrainer {
false false
} }
} finally { } finally {
incomingMessageObserver.removeKeepAliveToken(keepAliveToken) AppDependencies.authWebSocket.removeKeepAliveToken(keepAliveToken)
} }
} }

View File

@@ -5,6 +5,7 @@ import androidx.annotation.NonNull;
import org.signal.core.util.logging.Log; import org.signal.core.util.logging.Log;
import org.thoughtcrime.securesms.dependencies.AppDependencies; import org.thoughtcrime.securesms.dependencies.AppDependencies;
import org.thoughtcrime.securesms.keyvalue.SignalStore; import org.thoughtcrime.securesms.keyvalue.SignalStore;
import org.whispersystems.signalservice.api.websocket.SignalWebSocket;
import java.io.IOException; import java.io.IOException;
@@ -50,11 +51,13 @@ public final class DeviceTransferBlockingInterceptor implements Interceptor {
public void blockNetwork() { public void blockNetwork() {
blockNetworking = true; blockNetworking = true;
SignalWebSocket.setCanConnect(false);
AppDependencies.resetNetwork(); AppDependencies.resetNetwork();
} }
public void unblockNetwork() { public void unblockNetwork() {
blockNetworking = false; blockNetworking = false;
SignalWebSocket.setCanConnect(true);
AppDependencies.startNetwork(); AppDependencies.startNetwork();
} }
} }

View File

@@ -109,7 +109,7 @@ class SignalWebSocketHealthMonitor(
} }
private fun sendKeepAlives(): Boolean { private fun sendKeepAlives(): Boolean {
return needsKeepAlive && webSocket?.shouldSendKeepAlives == true return needsKeepAlive && webSocket?.shouldSendKeepAlives() == true
} }
/** /**

View File

@@ -79,6 +79,7 @@ import org.whispersystems.signalservice.api.AccountEntropyPool
import org.whispersystems.signalservice.api.SvrNoDataException import org.whispersystems.signalservice.api.SvrNoDataException
import org.whispersystems.signalservice.api.kbs.MasterKey import org.whispersystems.signalservice.api.kbs.MasterKey
import org.whispersystems.signalservice.api.svr.Svr3Credentials import org.whispersystems.signalservice.api.svr.Svr3Credentials
import org.whispersystems.signalservice.api.websocket.WebSocketUnavailableException
import org.whispersystems.signalservice.internal.push.AuthCredentials import org.whispersystems.signalservice.internal.push.AuthCredentials
import java.io.IOException import java.io.IOException
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
@@ -875,7 +876,11 @@ class RegistrationViewModel : ViewModel() {
SignalStore.registration.localRegistrationMetadata = metadata SignalStore.registration.localRegistrationMetadata = metadata
RegistrationRepository.registerAccountLocally(context, metadata) RegistrationRepository.registerAccountLocally(context, metadata)
AppDependencies.authWebSocket.connect() try {
AppDependencies.authWebSocket.connect()
} catch (e: WebSocketUnavailableException) {
Log.w(TAG, "Unable to start auth websocket", e)
}
if (!remoteResult.storageCapable && SignalStore.registration.restoreDecisionState.isDecisionPending) { if (!remoteResult.storageCapable && SignalStore.registration.restoreDecisionState.isDecisionPending) {
Log.v(TAG, "Not storage capable and still pending restore decision, likely an account with no data to restore, skipping post register restore") Log.v(TAG, "Not storage capable and still pending restore decision, likely an account with no data to restore, skipping post register restore")

View File

@@ -17,7 +17,6 @@ import android.net.ConnectivityManager
import android.os.Build import android.os.Build
import android.telephony.PhoneStateListener import android.telephony.PhoneStateListener
import android.telephony.TelephonyManager import android.telephony.TelephonyManager
import androidx.annotation.MainThread
import androidx.annotation.RequiresApi import androidx.annotation.RequiresApi
import androidx.core.app.NotificationManagerCompat import androidx.core.app.NotificationManagerCompat
import androidx.core.os.bundleOf import androidx.core.os.bundleOf
@@ -27,7 +26,6 @@ import io.reactivex.rxjava3.disposables.Disposable
import io.reactivex.rxjava3.kotlin.subscribeBy import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.schedulers.Schedulers import io.reactivex.rxjava3.schedulers.Schedulers
import org.signal.core.util.PendingIntentFlags import org.signal.core.util.PendingIntentFlags
import org.signal.core.util.ThreadUtil
import org.signal.core.util.logging.Log import org.signal.core.util.logging.Log
import org.thoughtcrime.securesms.dependencies.AppDependencies import org.thoughtcrime.securesms.dependencies.AppDependencies
import org.thoughtcrime.securesms.jobs.UnableToStartException import org.thoughtcrime.securesms.jobs.UnableToStartException
@@ -45,8 +43,6 @@ import org.thoughtcrime.securesms.webrtc.audio.SignalAudioManager.Companion.crea
import org.thoughtcrime.securesms.webrtc.locks.LockManager import org.thoughtcrime.securesms.webrtc.locks.LockManager
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
import kotlin.time.Duration
import kotlin.time.Duration.Companion.minutes
/** /**
* Entry point for [SignalCallManager] and friends to interact with the Android system. * Entry point for [SignalCallManager] and friends to interact with the Android system.
@@ -61,6 +57,8 @@ class ActiveCallManager(
companion object { companion object {
private val TAG = Log.tag(ActiveCallManager::class.java) private val TAG = Log.tag(ActiveCallManager::class.java)
private const val WEBSOCKET_KEEP_ALIVE_TOKEN: String = "ActiveCall"
private val requiresAsyncNotificationLoad = Build.VERSION.SDK_INT <= 29 private val requiresAsyncNotificationLoad = Build.VERSION.SDK_INT <= 29
private var activeCallManager: ActiveCallManager? = null private var activeCallManager: ActiveCallManager? = null
@@ -142,7 +140,6 @@ class ActiveCallManager(
private var networkReceiver: NetworkReceiver? = null private var networkReceiver: NetworkReceiver? = null
private var powerButtonReceiver: PowerButtonReceiver? = null private var powerButtonReceiver: PowerButtonReceiver? = null
private var uncaughtExceptionHandlerManager: UncaughtExceptionHandlerManager? = null private var uncaughtExceptionHandlerManager: UncaughtExceptionHandlerManager? = null
private val webSocketKeepAliveTask: WebSocketKeepAliveTask = WebSocketKeepAliveTask()
private var signalAudioManager: SignalAudioManager? = null private var signalAudioManager: SignalAudioManager? = null
private var previousNotificationId = -1 private var previousNotificationId = -1
private var previousNotificationDisposable = Disposable.disposed() private var previousNotificationDisposable = Disposable.disposed()
@@ -153,7 +150,8 @@ class ActiveCallManager(
registerUncaughtExceptionHandler() registerUncaughtExceptionHandler()
registerNetworkReceiver() registerNetworkReceiver()
webSocketKeepAliveTask.start() AppDependencies.authWebSocket.registerKeepAliveToken(WEBSOCKET_KEEP_ALIVE_TOKEN)
AppDependencies.unauthWebSocket.registerKeepAliveToken(WEBSOCKET_KEEP_ALIVE_TOKEN)
} }
fun shutdown() { fun shutdown() {
@@ -170,7 +168,8 @@ class ActiveCallManager(
unregisterNetworkReceiver() unregisterNetworkReceiver()
unregisterPowerButtonReceiver() unregisterPowerButtonReceiver()
webSocketKeepAliveTask.stop() AppDependencies.authWebSocket.removeKeepAliveToken(WEBSOCKET_KEEP_ALIVE_TOKEN)
AppDependencies.unauthWebSocket.removeKeepAliveToken(WEBSOCKET_KEEP_ALIVE_TOKEN)
if (!ActiveCallForegroundService.stop(application) && previousNotificationId != -1) { if (!ActiveCallForegroundService.stop(application) && previousNotificationId != -1) {
NotificationManagerCompat.from(application).cancel(previousNotificationId) NotificationManagerCompat.from(application).cancel(previousNotificationId)
@@ -433,42 +432,6 @@ class ActiveCallManager(
} }
} }
/**
* Periodically request the web socket stay open if we are doing anything call related.
*/
private class WebSocketKeepAliveTask : Runnable {
companion object {
private val REQUEST_WEBSOCKET_STAY_OPEN_DELAY: Duration = 1.minutes
private val WEBSOCKET_KEEP_ALIVE_TOKEN: String = WebSocketKeepAliveTask::class.java.simpleName
}
private var keepRunning = false
@MainThread
fun start() {
if (!keepRunning) {
keepRunning = true
run()
}
}
@MainThread
fun stop() {
keepRunning = false
ThreadUtil.cancelRunnableOnMain(this)
AppDependencies.incomingMessageObserver.removeKeepAliveToken(WEBSOCKET_KEEP_ALIVE_TOKEN)
}
@MainThread
override fun run() {
if (keepRunning) {
AppDependencies.incomingMessageObserver.registerKeepAliveToken(WEBSOCKET_KEEP_ALIVE_TOKEN)
ThreadUtil.runOnMainDelayed(this, REQUEST_WEBSOCKET_STAY_OPEN_DELAY.inWholeMilliseconds)
}
}
}
private class NetworkReceiver : BroadcastReceiver() { private class NetworkReceiver : BroadcastReceiver() {
override fun onReceive(context: Context, intent: Intent) { override fun onReceive(context: Context, intent: Intent) {
val connectivityManager = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager val connectivityManager = context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager

View File

@@ -42,7 +42,7 @@ public final class SignalProxyUtil {
AppDependencies.resetNetwork(); AppDependencies.resetNetwork();
} }
AppDependencies.startNetwork(); SignalExecutors.UNBOUNDED.execute(AppDependencies::startNetwork);
} }
/** /**

View File

@@ -78,7 +78,7 @@ class AccountApi(private val authWebSocket: SignalWebSocket.AuthenticatedWebSock
/** /**
* PUT /v1/accounts/registration_lock * PUT /v1/accounts/registration_lock
* - 200: Success * - 204: Success
*/ */
fun enableRegistrationLock(registrationLock: String): NetworkResult<Unit> { fun enableRegistrationLock(registrationLock: String): NetworkResult<Unit> {
val request = WebSocketRequestMessage.put("/v1/accounts/registration_lock", PushServiceSocket.RegistrationLockV2(registrationLock)) val request = WebSocketRequestMessage.put("/v1/accounts/registration_lock", PushServiceSocket.RegistrationLockV2(registrationLock))

View File

@@ -17,6 +17,10 @@ public interface CredentialsProvider {
int getDeviceId(); int getDeviceId();
String getPassword(); String getPassword();
default boolean isInvalid() {
return (getAci() == null && getE164() == null) || getPassword() == null;
}
default String getUsername() { default String getUsername() {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append(getAci().toString()); sb.append(getAci().toString());

View File

@@ -16,6 +16,7 @@ import org.signal.core.util.logging.Log
import org.signal.core.util.orNull import org.signal.core.util.orNull
import org.whispersystems.signalservice.api.crypto.SealedSenderAccess import org.whispersystems.signalservice.api.crypto.SealedSenderAccess
import org.whispersystems.signalservice.api.messages.EnvelopeResponse import org.whispersystems.signalservice.api.messages.EnvelopeResponse
import org.whispersystems.signalservice.api.util.SleepTimer
import org.whispersystems.signalservice.internal.push.Envelope import org.whispersystems.signalservice.internal.push.Envelope
import org.whispersystems.signalservice.internal.websocket.WebSocketConnection import org.whispersystems.signalservice.internal.websocket.WebSocketConnection
import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage import org.whispersystems.signalservice.internal.websocket.WebSocketRequestMessage
@@ -24,42 +25,56 @@ import org.whispersystems.signalservice.internal.websocket.WebsocketResponse
import java.io.IOException import java.io.IOException
import java.util.concurrent.TimeoutException import java.util.concurrent.TimeoutException
import kotlin.time.Duration import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
/** /**
* Base wrapper around a [WebSocketConnection] to provide a more developer friend interface to websocket * Base wrapper around a [WebSocketConnection] to provide a more developer friend interface to websocket
* interactions. * interactions.
*/ */
sealed class SignalWebSocket( sealed class SignalWebSocket(
private val createConnection: () -> WebSocketConnection private val connectionFactory: WebSocketFactory,
val sleepTimer: SleepTimer,
private val disconnectTimeout: Duration
) { ) {
companion object { companion object {
private val TAG = Log.tag(SignalWebSocket::class) private val TAG = Log.tag(SignalWebSocket::class)
const val SERVER_DELIVERED_TIMESTAMP_HEADER = "X-Signal-Timestamp" const val SERVER_DELIVERED_TIMESTAMP_HEADER = "X-Signal-Timestamp"
const val FOREGROUND_KEEPALIVE = "Foregrounded"
/**
* Set to false to prevent web sockets from connecting. After setting back to true the caller
* must manually start the sockets again by calling [connect].
*/
@Volatile
@JvmStatic
var canConnect: Boolean = true
} }
private var connection: WebSocketConnection? = null private var connection: WebSocketConnection? = null
private val connectionName
get() = connection?.name ?: "[null]"
private val _state: BehaviorSubject<WebSocketConnectionState> = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED) private val _state: BehaviorSubject<WebSocketConnectionState> = BehaviorSubject.createDefault(WebSocketConnectionState.DISCONNECTED)
protected var disposable: CompositeDisposable = CompositeDisposable() protected var disposable: CompositeDisposable = CompositeDisposable()
private var canConnect = false private val keepAliveTokens: MutableSet<String> = mutableSetOf()
var shouldSendKeepAlives: Boolean = true
set(value) {
field = value
keepAliveChangedListener?.invoke()
}
var keepAliveChangedListener: (() -> Unit)? = null var keepAliveChangedListener: (() -> Unit)? = null
private var delayedDisconnectThread: DelayedDisconnectThread? = null
val state: Observable<WebSocketConnectionState> = _state val state: Observable<WebSocketConnectionState> = _state
val stateSnapshot: WebSocketConnectionState
get() = _state.value!!
/** /**
* Indicate that WebSocketConnection can now be made and attempt to connect. * Indicate that WebSocketConnection can now be made and attempt to connect.
*/ */
@Synchronized @Synchronized
@Throws(WebSocketUnavailableException::class)
fun connect() { fun connect() {
canConnect = true
getWebSocket() getWebSocket()
} }
@@ -68,11 +83,6 @@ sealed class SignalWebSocket(
*/ */
@Synchronized @Synchronized
fun disconnect() { fun disconnect() {
canConnect = false
disconnectInternal()
}
private fun disconnectInternal() {
if (connection != null) { if (connection != null) {
disposable.dispose() disposable.dispose()
@@ -89,12 +99,53 @@ sealed class SignalWebSocket(
@Throws(IOException::class) @Throws(IOException::class)
fun sendKeepAlive() { fun sendKeepAlive() {
if (canConnect) { if (canConnect) {
Log.v(TAG, "$connectionName keepAliveTokens: $keepAliveTokens")
getWebSocket().sendKeepAlive() getWebSocket().sendKeepAlive()
} }
} }
@Synchronized
fun shouldSendKeepAlives(): Boolean {
return keepAliveTokens.isNotEmpty()
}
@Synchronized
fun registerKeepAliveToken(token: String) {
delayedDisconnectThread?.abort()
delayedDisconnectThread = null
val changed = keepAliveTokens.add(token)
if (changed) {
Log.v(TAG, "$connectionName Adding keepAliveToken: $token, current: $keepAliveTokens")
}
if (canConnect) {
try {
connect()
} catch (e: WebSocketUnavailableException) {
Log.w(TAG, "$connectionName Keep alive requested, but connection not available", e)
}
} else {
Log.w(TAG, "$connectionName Keep alive requested, but connection not available")
}
if (changed) {
keepAliveChangedListener?.invoke()
}
}
@Synchronized
fun removeKeepAliveToken(token: String) {
if (keepAliveTokens.remove(token)) {
Log.v(TAG, "$connectionName Removing keepAliveToken: $token, remaining: $keepAliveTokens")
startDelayedDisconnectIfNecessary()
keepAliveChangedListener?.invoke()
}
}
fun request(request: WebSocketRequestMessage): Single<WebsocketResponse> { fun request(request: WebSocketRequestMessage): Single<WebsocketResponse> {
return try { return try {
delayedDisconnectThread?.resetLastInteractionTime()
getWebSocket().sendRequest(request) getWebSocket().sendRequest(request)
} catch (e: IOException) { } catch (e: IOException) {
Single.error(e) Single.error(e)
@@ -103,6 +154,7 @@ sealed class SignalWebSocket(
fun request(request: WebSocketRequestMessage, timeout: Duration): Single<WebsocketResponse> { fun request(request: WebSocketRequestMessage, timeout: Duration): Single<WebsocketResponse> {
return try { return try {
delayedDisconnectThread?.resetLastInteractionTime()
getWebSocket().sendRequest(request, timeout.inWholeSeconds) getWebSocket().sendRequest(request, timeout.inWholeSeconds)
} catch (e: IOException) { } catch (e: IOException) {
Single.error(e) Single.error(e)
@@ -125,7 +177,7 @@ sealed class SignalWebSocket(
disposable.dispose() disposable.dispose()
disposable = CompositeDisposable() disposable = CompositeDisposable()
val newConnection = createConnection() val newConnection = connectionFactory.createConnection()
newConnection newConnection
.connect() .connect()
@@ -135,15 +187,70 @@ sealed class SignalWebSocket(
.addTo(disposable) .addTo(disposable)
this.connection = newConnection this.connection = newConnection
startDelayedDisconnectIfNecessary()
} }
return connection!! return connection!!
} }
private fun startDelayedDisconnectIfNecessary() {
if (connection.isAlive() && keepAliveTokens.isEmpty()) {
delayedDisconnectThread?.abort()
delayedDisconnectThread = DelayedDisconnectThread().also { it.start() }
}
}
@Synchronized @Synchronized
fun forceNewWebSocket() { fun forceNewWebSocket() {
Log.i(TAG, "Forcing new WebSockets connection: ${connection?.name ?: "[null]"} canConnect: $canConnect") Log.i(TAG, "$connectionName Forcing new WebSocket, canConnect: $canConnect")
disconnectInternal() disconnect()
}
/**
* Allow the WebSocket to self destruct if there are no keep alive tokens and it's been longer
* than [disconnectTimeout] since the last request was made.
*/
private inner class DelayedDisconnectThread : Thread() {
private var abort = false
@Volatile
private var lastInteractionTime = Duration.ZERO
fun abort() {
if (!abort && isAlive) {
Log.v(TAG, "$connectionName Scheduled disconnect aborted.")
abort = true
interrupt()
}
}
fun resetLastInteractionTime() {
lastInteractionTime = System.currentTimeMillis().milliseconds
}
override fun run() {
lastInteractionTime = System.currentTimeMillis().milliseconds
try {
while (!abort && (lastInteractionTime + disconnectTimeout) > System.currentTimeMillis().milliseconds) {
val now = System.currentTimeMillis().milliseconds
if (lastInteractionTime > now) {
lastInteractionTime = now
}
val sleepDuration = (lastInteractionTime + disconnectTimeout) - now
Log.v(TAG, "$connectionName Disconnect scheduled in $sleepDuration")
sleepTimer.sleep(sleepDuration.inWholeMilliseconds)
}
} catch (_: InterruptedException) { }
if (!abort && !shouldSendKeepAlives()) {
disconnect()
}
}
}
private fun WebSocketConnection?.isAlive(): Boolean {
return this?.isDead() == false
} }
protected fun WebSocketRequestMessage.isSignalServiceEnvelope(): Boolean { protected fun WebSocketRequestMessage.isSignalServiceEnvelope(): Boolean {
@@ -173,7 +280,7 @@ sealed class SignalWebSocket(
/** /**
* WebSocket type for communicating with the server without authenticating. Also known as "unidentified". * WebSocket type for communicating with the server without authenticating. Also known as "unidentified".
*/ */
class UnauthenticatedWebSocket(createConnection: () -> WebSocketConnection) : SignalWebSocket(createConnection) { class UnauthenticatedWebSocket(connectionFactory: WebSocketFactory, sleepTimer: SleepTimer, disconnectTimeoutMs: Long) : SignalWebSocket(connectionFactory, sleepTimer, disconnectTimeoutMs.milliseconds) {
fun request(requestMessage: WebSocketRequestMessage, sealedSenderAccess: SealedSenderAccess): Single<WebsocketResponse> { fun request(requestMessage: WebSocketRequestMessage, sealedSenderAccess: SealedSenderAccess): Single<WebsocketResponse> {
val headers: MutableList<String> = requestMessage.headers.toMutableList() val headers: MutableList<String> = requestMessage.headers.toMutableList()
headers.add(sealedSenderAccess.header) headers.add(sealedSenderAccess.header)
@@ -184,8 +291,7 @@ sealed class SignalWebSocket(
.build() .build()
try { try {
return getWebSocket() return request(message)
.sendRequest(message)
.flatMap<WebsocketResponse> { response -> .flatMap<WebsocketResponse> { response ->
if (response.status == 401) { if (response.status == 401) {
val fallback = sealedSenderAccess.switchToFallback() val fallback = sealedSenderAccess.switchToFallback()
@@ -204,7 +310,7 @@ sealed class SignalWebSocket(
/** /**
* WebSocket type for communicating with the server with authentication. Also known as "identified". * WebSocket type for communicating with the server with authentication. Also known as "identified".
*/ */
class AuthenticatedWebSocket(createConnection: () -> WebSocketConnection) : SignalWebSocket(createConnection) { class AuthenticatedWebSocket(connectionFactory: WebSocketFactory, sleepTimer: SleepTimer, disconnectTimeoutMs: Long) : SignalWebSocket(connectionFactory, sleepTimer, disconnectTimeoutMs.milliseconds) {
/** /**
* The reads a batch of messages off of the websocket. * The reads a batch of messages off of the websocket.

View File

@@ -3,6 +3,5 @@ package org.whispersystems.signalservice.api.websocket;
import org.whispersystems.signalservice.internal.websocket.WebSocketConnection; import org.whispersystems.signalservice.internal.websocket.WebSocketConnection;
public interface WebSocketFactory { public interface WebSocketFactory {
WebSocketConnection createWebSocket(); WebSocketConnection createConnection() throws WebSocketUnavailableException;
WebSocketConnection createUnidentifiedWebSocket();
} }

View File

@@ -4,11 +4,15 @@ import java.io.IOException;
/** /**
* Thrown when the WebSocket is not available for use by runtime policy. Currently, the * Thrown when the WebSocket is not available for use by runtime policy. Currently, the
* WebSocket is only available when the app is in the foreground and requested via IncomingMessageObserver. * WebSocket is only unavailable when networking is blocked by a device transfer or if
* Or, when using WebSocket Strategy. * requesting to connect via auth but provide no auth credentials.
*/ */
public final class WebSocketUnavailableException extends IOException { public final class WebSocketUnavailableException extends IOException {
public WebSocketUnavailableException() { public WebSocketUnavailableException() {
super("WebSocket not currently available."); super("WebSocket not currently available.");
} }
public WebSocketUnavailableException(String reason) {
super("WebSocket not currently available. Reason: " + reason);
}
} }