Fail websocket drain if keepalive token is removed.

This commit is contained in:
Clark
2023-09-13 12:24:54 -04:00
committed by Alex Hart
parent 11e0dd18d3
commit 51e46db42d
3 changed files with 48 additions and 29 deletions

View File

@@ -43,6 +43,8 @@ object FcmFetchManager {
private val TAG = Log.tag(FcmFetchManager::class.java)
private val EXECUTOR = SerialMonoLifoExecutor(SignalExecutors.UNBOUNDED)
private val KEEP_ALIVE_TOKEN = "FcmFetch"
val WEBSOCKET_DRAIN_TIMEOUT = 5.minutes.inWholeMilliseconds
@Volatile
@@ -140,7 +142,7 @@ object FcmFetchManager {
@JvmStatic
fun retrieveMessages(context: Context): Boolean {
val success = WebSocketDrainer.blockUntilDrainedAndProcessed(WEBSOCKET_DRAIN_TIMEOUT)
val success = WebSocketDrainer.blockUntilDrainedAndProcessed(WEBSOCKET_DRAIN_TIMEOUT, KEEP_ALIVE_TOKEN)
if (success) {
Log.i(TAG, "Successfully retrieved messages.")

View File

@@ -79,6 +79,7 @@ class IncomingMessageObserver(private val context: Application) {
private val decryptionDrainedListeners: MutableList<Runnable> = CopyOnWriteArrayList()
private val keepAliveTokens: MutableMap<String, Long> = mutableMapOf()
private val keepAlivePurgeCallbacks: MutableMap<String, MutableList<Runnable>> = mutableMapOf()
private val lock: ReentrantLock = ReentrantLock()
private val connectionNecessarySemaphore = Semaphore(0)
@@ -181,12 +182,15 @@ class IncomingMessageObserver(private val context: Application) {
timeIdle = if (appVisibleSnapshot) 0 else System.currentTimeMillis() - lastInteractionTime
val keepAliveCutoffTime = System.currentTimeMillis() - keepAliveTokenMaxAge
val removedKeepAliveToken = keepAliveTokens.entries.removeIf { (_, createTime) -> createTime < keepAliveCutoffTime }
if (removedKeepAliveToken) {
Log.d(TAG, "Removed old keep web socket open requests.")
}
keepAliveEntries = keepAliveTokens.entries.map { it.key to it.value }.toImmutableSet()
keepAliveEntries = keepAliveTokens.entries.mapNotNull { (key, createTime) ->
if (createTime < keepAliveCutoffTime) {
Log.d(TAG, "Removed old keep web socket keep alive token $key")
keepAlivePurgeCallbacks.remove(key)?.forEach { it.run() }
null
} else {
key to createTime
}
}.toImmutableSet()
}
val registered = SignalStore.account().isRegistered
@@ -235,9 +239,16 @@ class IncomingMessageObserver(private val context: Application) {
ApplicationDependencies.getSignalWebSocket().disconnect()
}
fun registerKeepAliveToken(key: String) {
@JvmOverloads
fun registerKeepAliveToken(key: String, runnable: Runnable? = null) {
lock.withLock {
keepAliveTokens[key] = System.currentTimeMillis()
if (runnable != null) {
if (!keepAlivePurgeCallbacks.containsKey(key)) {
keepAlivePurgeCallbacks[key] = ArrayList()
}
keepAlivePurgeCallbacks[key]?.add(runnable)
}
lastInteractionTime = System.currentTimeMillis()
connectionNecessarySemaphore.release()
}
@@ -246,6 +257,7 @@ class IncomingMessageObserver(private val context: Application) {
fun removeKeepAliveToken(key: String) {
lock.withLock {
keepAliveTokens.remove(key)
keepAlivePurgeCallbacks.remove(key)
lastInteractionTime = System.currentTimeMillis()
connectionNecessarySemaphore.release()
}

View File

@@ -45,7 +45,7 @@ object WebSocketDrainer {
* Also, if it is discovered that it's unlikely that we'll be able to fetch messages (i.e. no network), then the timeout may be reduced.
*/
@WorkerThread
fun blockUntilDrainedAndProcessed(requestedWebsocketDrainTimeoutMs: Long): Boolean {
fun blockUntilDrainedAndProcessed(requestedWebsocketDrainTimeoutMs: Long, keepAliveToken: String = KEEP_ALIVE_TOKEN): Boolean {
Log.d(TAG, "blockUntilDrainedAndProcessed() requestedWebsocketDrainTimeout: $requestedWebsocketDrainTimeoutMs ms")
var websocketDrainTimeout = requestedWebsocketDrainTimeoutMs
@@ -66,16 +66,13 @@ object WebSocketDrainer {
websocketDrainTimeout = NO_NETWORK_WEBSOCKET_TIMEOUT
}
incomingMessageObserver.registerKeepAliveToken(KEEP_ALIVE_TOKEN)
val wakeLockTag = WAKELOCK_PREFIX + System.currentTimeMillis()
val wakeLock = WakeLockUtil.acquire(ApplicationDependencies.getApplication(), PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeout + QUEUE_TIMEOUT, wakeLockTag)
return try {
drainAndProcess(websocketDrainTimeout, incomingMessageObserver)
drainAndProcess(websocketDrainTimeout, incomingMessageObserver, keepAliveToken)
} finally {
WakeLockUtil.release(wakeLock, wakeLockTag)
incomingMessageObserver.removeKeepAliveToken(KEEP_ALIVE_TOKEN)
}
}
@@ -86,7 +83,7 @@ object WebSocketDrainer {
* so that we know the queue has been drained.
*/
@WorkerThread
private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver): Boolean {
private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver, keepAliveToken: String): Boolean {
val stopwatch = Stopwatch("websocket-strategy")
val jobManager = ApplicationDependencies.getJobManager()
@@ -97,7 +94,7 @@ object WebSocketDrainer {
queueListener
)
val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout)
val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout, keepAliveToken)
if (!successfullyDrained) {
return false
}
@@ -119,25 +116,33 @@ object WebSocketDrainer {
return true
}
private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long): Boolean {
val latch = CountDownLatch(1)
incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
override fun run() {
private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long, keepAliveToken: String): Boolean {
try {
val latch = CountDownLatch(1)
var success = false
incomingMessageObserver.registerKeepAliveToken(keepAliveToken) {
Log.w(TAG, "Keep alive token purged")
latch.countDown()
incomingMessageObserver.removeDecryptionDrainedListener(this)
}
})
incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
override fun run() {
success = true
latch.countDown()
incomingMessageObserver.removeDecryptionDrainedListener(this)
}
})
return try {
if (latch.await(timeoutMs, TimeUnit.MILLISECONDS)) {
true
} else {
Log.w(TAG, "Hit timeout while waiting for decryptions to drain!")
return try {
if (!latch.await(timeoutMs, TimeUnit.MILLISECONDS)) {
Log.w(TAG, "Hit timeout while waiting for decryptions to drain!")
}
success
} catch (e: InterruptedException) {
Log.w(TAG, "Interrupted!", e)
false
}
} catch (e: InterruptedException) {
Log.w(TAG, "Interrupted!", e)
false
} finally {
incomingMessageObserver.removeKeepAliveToken(keepAliveToken)
}
}