mirror of
https://github.com/signalapp/Signal-Android.git
synced 2026-04-21 09:20:19 +01:00
Improve auth WebSocket lifecycle.
This commit is contained in:
committed by
Alex Hart
parent
6bbd899507
commit
323697dfc9
@@ -7,7 +7,9 @@ import android.content.Intent
|
||||
import android.os.IBinder
|
||||
import androidx.annotation.VisibleForTesting
|
||||
import androidx.core.app.NotificationCompat
|
||||
import kotlinx.collections.immutable.toImmutableSet
|
||||
import io.reactivex.rxjava3.disposables.Disposable
|
||||
import io.reactivex.rxjava3.kotlin.subscribeBy
|
||||
import io.reactivex.rxjava3.schedulers.Schedulers
|
||||
import org.signal.core.util.concurrent.SignalExecutors
|
||||
import org.signal.core.util.logging.Log
|
||||
import org.thoughtcrime.securesms.R
|
||||
@@ -52,23 +54,22 @@ import kotlin.time.Duration.Companion.minutes
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
/**
|
||||
* The application-level manager of our websocket connection.
|
||||
* The application-level manager of our incoming message processing.
|
||||
*
|
||||
* This class is responsible for opening/closing the websocket based on the app's state and observing new inbound messages received on the websocket.
|
||||
* This class is responsible for keeping the authenticated websocket open based on the app's state for incoming messages and
|
||||
* observing new inbound messages received over the websocket.
|
||||
*/
|
||||
class IncomingMessageObserver(private val context: Application, private val authWebSocket: SignalWebSocket.AuthenticatedWebSocket) {
|
||||
|
||||
companion object {
|
||||
private val TAG = Log.tag(IncomingMessageObserver::class.java)
|
||||
|
||||
private const val WEB_SOCKET_KEEP_ALIVE_TOKEN = "MessageRetrieval"
|
||||
|
||||
/** How long we wait for the websocket to time out before we try to connect again. */
|
||||
private val websocketReadTimeout: Long
|
||||
get() = if (censored) 30.seconds.inWholeMilliseconds else 1.minutes.inWholeMilliseconds
|
||||
|
||||
/** How long a keep-alive token is allowed to keep the websocket open for. These are usually used for calling + FCM messages. */
|
||||
private val keepAliveTokenMaxAge: Long
|
||||
get() = if (censored) 2.minutes.inWholeMilliseconds else 5.minutes.inWholeMilliseconds
|
||||
|
||||
/** How long the websocket is allowed to keep running after the user backgrounds the app. Higher numbers allow us to rely on FCM less. */
|
||||
private val maxBackgroundTime: Long
|
||||
get() = if (censored) 10.seconds.inWholeMilliseconds else 2.minutes.inWholeMilliseconds
|
||||
@@ -82,8 +83,6 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
}
|
||||
|
||||
private val decryptionDrainedListeners: MutableList<Runnable> = CopyOnWriteArrayList()
|
||||
private val keepAliveTokens: MutableMap<String, Long> = mutableMapOf()
|
||||
private val keepAlivePurgeCallbacks: MutableMap<String, MutableList<Runnable>> = mutableMapOf()
|
||||
|
||||
private val lock: ReentrantLock = ReentrantLock()
|
||||
private val connectionNecessarySemaphore = Semaphore(0)
|
||||
@@ -91,9 +90,8 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
lock.withLock {
|
||||
AppDependencies.libsignalNetwork.onNetworkChange()
|
||||
if (isNetworkUnavailable()) {
|
||||
Log.w(TAG, "Lost network connection. Shutting down our websocket connections and resetting the drained state.")
|
||||
Log.w(TAG, "Lost network connection. Resetting the drained state.")
|
||||
decryptionDrained = false
|
||||
disconnect()
|
||||
}
|
||||
connectionNecessarySemaphore.release()
|
||||
}
|
||||
@@ -103,6 +101,7 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
|
||||
private var appVisible = false
|
||||
private var lastInteractionTime: Long = System.currentTimeMillis()
|
||||
private var webSocketStateDisposable = Disposable.disposed()
|
||||
|
||||
@Volatile
|
||||
private var terminated = false
|
||||
@@ -144,6 +143,17 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
})
|
||||
|
||||
networkConnectionListener.register()
|
||||
|
||||
webSocketStateDisposable = authWebSocket
|
||||
.state
|
||||
.observeOn(Schedulers.computation())
|
||||
.subscribeBy {
|
||||
if (it == WebSocketConnectionState.CONNECTED) {
|
||||
lock.withLock {
|
||||
connectionNecessarySemaphore.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun notifyRegistrationStateChanged() {
|
||||
@@ -179,23 +189,11 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
|
||||
private fun isConnectionNecessary(): Boolean {
|
||||
val timeIdle: Long
|
||||
val keepAliveEntries: Set<Pair<String, Long>>
|
||||
val appVisibleSnapshot: Boolean
|
||||
|
||||
lock.withLock {
|
||||
appVisibleSnapshot = appVisible
|
||||
timeIdle = if (appVisibleSnapshot) 0 else System.currentTimeMillis() - lastInteractionTime
|
||||
|
||||
val keepAliveCutoffTime = System.currentTimeMillis() - keepAliveTokenMaxAge
|
||||
keepAliveEntries = keepAliveTokens.entries.mapNotNull { (key, createTime) ->
|
||||
if (createTime < keepAliveCutoffTime) {
|
||||
Log.d(TAG, "Removed old keep web socket keep alive token $key")
|
||||
keepAlivePurgeCallbacks.remove(key)?.forEach { it.run() }
|
||||
null
|
||||
} else {
|
||||
key to createTime
|
||||
}
|
||||
}.toImmutableSet()
|
||||
}
|
||||
|
||||
val registered = SignalStore.account.isRegistered
|
||||
@@ -203,24 +201,33 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
val hasNetwork = NetworkConstraint.isMet(context)
|
||||
val hasProxy = SignalStore.proxy.isProxyEnabled
|
||||
val forceWebsocket = SignalStore.internal.isWebsocketModeForced
|
||||
val isRestoreDecisionPending = RemoteConfig.restoreAfterRegistration && SignalStore.registration.restoreDecisionState.isDecisionPending
|
||||
val websocketAlreadyOpen = isConnectionAvailable()
|
||||
val canProcessIncomingMessages = canProcessIncomingMessages()
|
||||
|
||||
val lastInteractionString = if (appVisibleSnapshot) "N/A" else timeIdle.toString() + " ms (" + (if (timeIdle < maxBackgroundTime) "within limit" else "over limit") + ")"
|
||||
val conclusion = registered &&
|
||||
(appVisibleSnapshot || timeIdle < maxBackgroundTime || !fcmEnabled || keepAliveEntries.isNotEmpty()) &&
|
||||
(appVisibleSnapshot || timeIdle < maxBackgroundTime || !fcmEnabled) &&
|
||||
hasNetwork &&
|
||||
!isRestoreDecisionPending
|
||||
canProcessIncomingMessages
|
||||
|
||||
val needsConnectionString = if (conclusion) "Needs Connection" else "Does Not Need Connection"
|
||||
|
||||
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, Stay open requests: $keepAliveEntries, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Pending restore: $isRestoreDecisionPending")
|
||||
Log.d(TAG, "[$needsConnectionString] Network: $hasNetwork, Foreground: $appVisibleSnapshot, Time Since Last Interaction: $lastInteractionString, FCM: $fcmEnabled, WS Connected: $websocketAlreadyOpen, Registered: $registered, Proxy: $hasProxy, Force websocket: $forceWebsocket, Can process messages: $canProcessIncomingMessages")
|
||||
return conclusion
|
||||
}
|
||||
|
||||
private fun isConnectionAvailable(): Boolean {
|
||||
return authWebSocket.stateSnapshot == WebSocketConnectionState.CONNECTED
|
||||
}
|
||||
|
||||
private fun canProcessIncomingMessages(): Boolean {
|
||||
return !(RemoteConfig.restoreAfterRegistration && SignalStore.registration.restoreDecisionState.isDecisionPending)
|
||||
}
|
||||
|
||||
private fun waitForConnectionNecessary() {
|
||||
try {
|
||||
connectionNecessarySemaphore.drainPermits()
|
||||
while (!isConnectionNecessary()) {
|
||||
while (!isConnectionNecessary() && !(isConnectionAvailable() && canProcessIncomingMessages())) {
|
||||
val numberDrained = connectionNecessarySemaphore.drainPermits()
|
||||
if (numberDrained == 0) {
|
||||
connectionNecessarySemaphore.acquire()
|
||||
@@ -235,38 +242,10 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
Log.w(TAG, "Termination Enqueued! ${this.hashCode()}", Throwable())
|
||||
INSTANCE_COUNT.decrementAndGet()
|
||||
networkConnectionListener.unregister()
|
||||
webSocketStateDisposable.dispose()
|
||||
SignalExecutors.BOUNDED.execute {
|
||||
Log.w(TAG, "Beginning termination. ${this.hashCode()}")
|
||||
terminated = true
|
||||
disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
private fun disconnect() {
|
||||
authWebSocket.disconnect()
|
||||
}
|
||||
|
||||
@JvmOverloads
|
||||
fun registerKeepAliveToken(key: String, runnable: Runnable? = null) {
|
||||
lock.withLock {
|
||||
keepAliveTokens[key] = System.currentTimeMillis()
|
||||
if (runnable != null) {
|
||||
if (!keepAlivePurgeCallbacks.containsKey(key)) {
|
||||
keepAlivePurgeCallbacks[key] = ArrayList()
|
||||
}
|
||||
keepAlivePurgeCallbacks[key]?.add(runnable)
|
||||
}
|
||||
lastInteractionTime = System.currentTimeMillis()
|
||||
connectionNecessarySemaphore.release()
|
||||
}
|
||||
}
|
||||
|
||||
fun removeKeepAliveToken(key: String) {
|
||||
lock.withLock {
|
||||
keepAliveTokens.remove(key)
|
||||
keepAlivePurgeCallbacks.remove(key)
|
||||
lastInteractionTime = System.currentTimeMillis()
|
||||
connectionNecessarySemaphore.release()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,9 +375,16 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
}
|
||||
}
|
||||
|
||||
authWebSocket.connect()
|
||||
try {
|
||||
while (!terminated && isConnectionNecessary()) {
|
||||
authWebSocket.connect()
|
||||
var isConnectionNecessary = false
|
||||
while (!terminated && canProcessIncomingMessages() && (isConnectionNecessary().also { isConnectionNecessary = it } || isConnectionAvailable())) {
|
||||
if (isConnectionNecessary) {
|
||||
authWebSocket.registerKeepAliveToken(WEB_SOCKET_KEEP_ALIVE_TOKEN)
|
||||
} else {
|
||||
authWebSocket.removeKeepAliveToken(WEB_SOCKET_KEEP_ALIVE_TOKEN)
|
||||
}
|
||||
|
||||
try {
|
||||
Log.d(TAG, "Reading message...")
|
||||
|
||||
@@ -461,8 +447,6 @@ class IncomingMessageObserver(private val context: Application, private val auth
|
||||
attempts++
|
||||
Log.w(TAG, e)
|
||||
} finally {
|
||||
Log.w(TAG, "Shutting down pipe...")
|
||||
disconnect()
|
||||
webSocketDisposable.dispose()
|
||||
}
|
||||
Log.i(TAG, "Looping...")
|
||||
|
||||
@@ -51,7 +51,6 @@ object WebSocketDrainer {
|
||||
var websocketDrainTimeout = requestedWebsocketDrainTimeoutMs
|
||||
|
||||
val context = AppDependencies.application
|
||||
val incomingMessageObserver = AppDependencies.incomingMessageObserver
|
||||
val powerManager = ServiceUtil.getPowerManager(context)
|
||||
|
||||
val doze = PowerManagerCompat.isDeviceIdleMode(powerManager)
|
||||
@@ -70,7 +69,7 @@ object WebSocketDrainer {
|
||||
val wakeLock = WakeLockUtil.acquire(AppDependencies.application, PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeout + QUEUE_TIMEOUT, wakeLockTag)
|
||||
|
||||
return try {
|
||||
drainAndProcess(websocketDrainTimeout, incomingMessageObserver, keepAliveToken)
|
||||
drainAndProcess(websocketDrainTimeout, keepAliveToken)
|
||||
} finally {
|
||||
WakeLockUtil.release(wakeLock, wakeLockTag)
|
||||
}
|
||||
@@ -83,7 +82,7 @@ object WebSocketDrainer {
|
||||
* so that we know the queue has been drained.
|
||||
*/
|
||||
@WorkerThread
|
||||
private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver, keepAliveToken: String): Boolean {
|
||||
private fun drainAndProcess(timeout: Long, keepAliveToken: String): Boolean {
|
||||
val stopwatch = Stopwatch("websocket-strategy")
|
||||
|
||||
val jobManager = AppDependencies.jobManager
|
||||
@@ -94,7 +93,7 @@ object WebSocketDrainer {
|
||||
queueListener
|
||||
)
|
||||
|
||||
val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout, keepAliveToken)
|
||||
val successfullyDrained = blockUntilWebsocketDrained(timeout, keepAliveToken)
|
||||
if (!successfullyDrained) {
|
||||
return false
|
||||
}
|
||||
@@ -116,19 +115,17 @@ object WebSocketDrainer {
|
||||
return true
|
||||
}
|
||||
|
||||
private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long, keepAliveToken: String): Boolean {
|
||||
private fun blockUntilWebsocketDrained(timeoutMs: Long, keepAliveToken: String): Boolean {
|
||||
try {
|
||||
val latch = CountDownLatch(1)
|
||||
var success = false
|
||||
incomingMessageObserver.registerKeepAliveToken(keepAliveToken) {
|
||||
Log.w(TAG, "Keep alive token purged")
|
||||
latch.countDown()
|
||||
}
|
||||
incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
|
||||
AppDependencies.authWebSocket.registerKeepAliveToken(keepAliveToken)
|
||||
|
||||
AppDependencies.incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
|
||||
override fun run() {
|
||||
success = true
|
||||
latch.countDown()
|
||||
incomingMessageObserver.removeDecryptionDrainedListener(this)
|
||||
AppDependencies.incomingMessageObserver.removeDecryptionDrainedListener(this)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -142,7 +139,7 @@ object WebSocketDrainer {
|
||||
false
|
||||
}
|
||||
} finally {
|
||||
incomingMessageObserver.removeKeepAliveToken(keepAliveToken)
|
||||
AppDependencies.authWebSocket.removeKeepAliveToken(keepAliveToken)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user