diff --git a/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt b/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt index a68a74f74a..281c8e628a 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/gcm/FcmFetchManager.kt @@ -43,6 +43,8 @@ object FcmFetchManager { private val TAG = Log.tag(FcmFetchManager::class.java) private val EXECUTOR = SerialMonoLifoExecutor(SignalExecutors.UNBOUNDED) + private val KEEP_ALIVE_TOKEN = "FcmFetch" + val WEBSOCKET_DRAIN_TIMEOUT = 5.minutes.inWholeMilliseconds @Volatile @@ -140,7 +142,7 @@ object FcmFetchManager { @JvmStatic fun retrieveMessages(context: Context): Boolean { - val success = WebSocketDrainer.blockUntilDrainedAndProcessed(WEBSOCKET_DRAIN_TIMEOUT) + val success = WebSocketDrainer.blockUntilDrainedAndProcessed(WEBSOCKET_DRAIN_TIMEOUT, KEEP_ALIVE_TOKEN) if (success) { Log.i(TAG, "Successfully retrieved messages.") diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt index 3b616eea66..004e3ff5af 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/IncomingMessageObserver.kt @@ -79,6 +79,7 @@ class IncomingMessageObserver(private val context: Application) { private val decryptionDrainedListeners: MutableList = CopyOnWriteArrayList() private val keepAliveTokens: MutableMap = mutableMapOf() + private val keepAlivePurgeCallbacks: MutableMap> = mutableMapOf() private val lock: ReentrantLock = ReentrantLock() private val connectionNecessarySemaphore = Semaphore(0) @@ -181,12 +182,15 @@ class IncomingMessageObserver(private val context: Application) { timeIdle = if (appVisibleSnapshot) 0 else System.currentTimeMillis() - lastInteractionTime val keepAliveCutoffTime = System.currentTimeMillis() - keepAliveTokenMaxAge - val removedKeepAliveToken = keepAliveTokens.entries.removeIf { (_, createTime) -> createTime < keepAliveCutoffTime } - if (removedKeepAliveToken) { - Log.d(TAG, "Removed old keep web socket open requests.") - } - - keepAliveEntries = keepAliveTokens.entries.map { it.key to it.value }.toImmutableSet() + keepAliveEntries = keepAliveTokens.entries.mapNotNull { (key, createTime) -> + if (createTime < keepAliveCutoffTime) { + Log.d(TAG, "Removed old keep web socket keep alive token $key") + keepAlivePurgeCallbacks.remove(key)?.forEach { it.run() } + null + } else { + key to createTime + } + }.toImmutableSet() } val registered = SignalStore.account().isRegistered @@ -235,9 +239,16 @@ class IncomingMessageObserver(private val context: Application) { ApplicationDependencies.getSignalWebSocket().disconnect() } - fun registerKeepAliveToken(key: String) { + @JvmOverloads + fun registerKeepAliveToken(key: String, runnable: Runnable? = null) { lock.withLock { keepAliveTokens[key] = System.currentTimeMillis() + if (runnable != null) { + if (!keepAlivePurgeCallbacks.containsKey(key)) { + keepAlivePurgeCallbacks[key] = ArrayList() + } + keepAlivePurgeCallbacks[key]?.add(runnable) + } lastInteractionTime = System.currentTimeMillis() connectionNecessarySemaphore.release() } @@ -246,6 +257,7 @@ class IncomingMessageObserver(private val context: Application) { fun removeKeepAliveToken(key: String) { lock.withLock { keepAliveTokens.remove(key) + keepAlivePurgeCallbacks.remove(key) lastInteractionTime = System.currentTimeMillis() connectionNecessarySemaphore.release() } diff --git a/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketDrainer.kt b/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketDrainer.kt index 7a33f71b4f..af974161f2 100644 --- a/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketDrainer.kt +++ b/app/src/main/java/org/thoughtcrime/securesms/messages/WebSocketDrainer.kt @@ -45,7 +45,7 @@ object WebSocketDrainer { * Also, if it is discovered that it's unlikely that we'll be able to fetch messages (i.e. no network), then the timeout may be reduced. */ @WorkerThread - fun blockUntilDrainedAndProcessed(requestedWebsocketDrainTimeoutMs: Long): Boolean { + fun blockUntilDrainedAndProcessed(requestedWebsocketDrainTimeoutMs: Long, keepAliveToken: String = KEEP_ALIVE_TOKEN): Boolean { Log.d(TAG, "blockUntilDrainedAndProcessed() requestedWebsocketDrainTimeout: $requestedWebsocketDrainTimeoutMs ms") var websocketDrainTimeout = requestedWebsocketDrainTimeoutMs @@ -66,16 +66,13 @@ object WebSocketDrainer { websocketDrainTimeout = NO_NETWORK_WEBSOCKET_TIMEOUT } - incomingMessageObserver.registerKeepAliveToken(KEEP_ALIVE_TOKEN) - val wakeLockTag = WAKELOCK_PREFIX + System.currentTimeMillis() val wakeLock = WakeLockUtil.acquire(ApplicationDependencies.getApplication(), PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeout + QUEUE_TIMEOUT, wakeLockTag) return try { - drainAndProcess(websocketDrainTimeout, incomingMessageObserver) + drainAndProcess(websocketDrainTimeout, incomingMessageObserver, keepAliveToken) } finally { WakeLockUtil.release(wakeLock, wakeLockTag) - incomingMessageObserver.removeKeepAliveToken(KEEP_ALIVE_TOKEN) } } @@ -86,7 +83,7 @@ object WebSocketDrainer { * so that we know the queue has been drained. */ @WorkerThread - private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver): Boolean { + private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver, keepAliveToken: String): Boolean { val stopwatch = Stopwatch("websocket-strategy") val jobManager = ApplicationDependencies.getJobManager() @@ -97,7 +94,7 @@ object WebSocketDrainer { queueListener ) - val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout) + val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout, keepAliveToken) if (!successfullyDrained) { return false } @@ -119,25 +116,33 @@ object WebSocketDrainer { return true } - private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long): Boolean { - val latch = CountDownLatch(1) - incomingMessageObserver.addDecryptionDrainedListener(object : Runnable { - override fun run() { + private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long, keepAliveToken: String): Boolean { + try { + val latch = CountDownLatch(1) + var success = false + incomingMessageObserver.registerKeepAliveToken(keepAliveToken) { + Log.w(TAG, "Keep alive token purged") latch.countDown() - incomingMessageObserver.removeDecryptionDrainedListener(this) } - }) + incomingMessageObserver.addDecryptionDrainedListener(object : Runnable { + override fun run() { + success = true + latch.countDown() + incomingMessageObserver.removeDecryptionDrainedListener(this) + } + }) - return try { - if (latch.await(timeoutMs, TimeUnit.MILLISECONDS)) { - true - } else { - Log.w(TAG, "Hit timeout while waiting for decryptions to drain!") + return try { + if (!latch.await(timeoutMs, TimeUnit.MILLISECONDS)) { + Log.w(TAG, "Hit timeout while waiting for decryptions to drain!") + } + success + } catch (e: InterruptedException) { + Log.w(TAG, "Interrupted!", e) false } - } catch (e: InterruptedException) { - Log.w(TAG, "Interrupted!", e) - false + } finally { + incomingMessageObserver.removeKeepAliveToken(keepAliveToken) } }