mirror of
https://github.com/signalapp/Signal-Android.git
synced 2026-02-28 13:48:12 +00:00
Fail websocket drain if keepalive token is removed.
This commit is contained in:
@@ -43,6 +43,8 @@ object FcmFetchManager {
|
||||
private val TAG = Log.tag(FcmFetchManager::class.java)
|
||||
private val EXECUTOR = SerialMonoLifoExecutor(SignalExecutors.UNBOUNDED)
|
||||
|
||||
private val KEEP_ALIVE_TOKEN = "FcmFetch"
|
||||
|
||||
val WEBSOCKET_DRAIN_TIMEOUT = 5.minutes.inWholeMilliseconds
|
||||
|
||||
@Volatile
|
||||
@@ -140,7 +142,7 @@ object FcmFetchManager {
|
||||
|
||||
@JvmStatic
|
||||
fun retrieveMessages(context: Context): Boolean {
|
||||
val success = WebSocketDrainer.blockUntilDrainedAndProcessed(WEBSOCKET_DRAIN_TIMEOUT)
|
||||
val success = WebSocketDrainer.blockUntilDrainedAndProcessed(WEBSOCKET_DRAIN_TIMEOUT, KEEP_ALIVE_TOKEN)
|
||||
|
||||
if (success) {
|
||||
Log.i(TAG, "Successfully retrieved messages.")
|
||||
|
||||
@@ -79,6 +79,7 @@ class IncomingMessageObserver(private val context: Application) {
|
||||
|
||||
private val decryptionDrainedListeners: MutableList<Runnable> = CopyOnWriteArrayList()
|
||||
private val keepAliveTokens: MutableMap<String, Long> = mutableMapOf()
|
||||
private val keepAlivePurgeCallbacks: MutableMap<String, MutableList<Runnable>> = mutableMapOf()
|
||||
|
||||
private val lock: ReentrantLock = ReentrantLock()
|
||||
private val connectionNecessarySemaphore = Semaphore(0)
|
||||
@@ -181,12 +182,15 @@ class IncomingMessageObserver(private val context: Application) {
|
||||
timeIdle = if (appVisibleSnapshot) 0 else System.currentTimeMillis() - lastInteractionTime
|
||||
|
||||
val keepAliveCutoffTime = System.currentTimeMillis() - keepAliveTokenMaxAge
|
||||
val removedKeepAliveToken = keepAliveTokens.entries.removeIf { (_, createTime) -> createTime < keepAliveCutoffTime }
|
||||
if (removedKeepAliveToken) {
|
||||
Log.d(TAG, "Removed old keep web socket open requests.")
|
||||
}
|
||||
|
||||
keepAliveEntries = keepAliveTokens.entries.map { it.key to it.value }.toImmutableSet()
|
||||
keepAliveEntries = keepAliveTokens.entries.mapNotNull { (key, createTime) ->
|
||||
if (createTime < keepAliveCutoffTime) {
|
||||
Log.d(TAG, "Removed old keep web socket keep alive token $key")
|
||||
keepAlivePurgeCallbacks.remove(key)?.forEach { it.run() }
|
||||
null
|
||||
} else {
|
||||
key to createTime
|
||||
}
|
||||
}.toImmutableSet()
|
||||
}
|
||||
|
||||
val registered = SignalStore.account().isRegistered
|
||||
@@ -235,9 +239,16 @@ class IncomingMessageObserver(private val context: Application) {
|
||||
ApplicationDependencies.getSignalWebSocket().disconnect()
|
||||
}
|
||||
|
||||
fun registerKeepAliveToken(key: String) {
|
||||
@JvmOverloads
|
||||
fun registerKeepAliveToken(key: String, runnable: Runnable? = null) {
|
||||
lock.withLock {
|
||||
keepAliveTokens[key] = System.currentTimeMillis()
|
||||
if (runnable != null) {
|
||||
if (!keepAlivePurgeCallbacks.containsKey(key)) {
|
||||
keepAlivePurgeCallbacks[key] = ArrayList()
|
||||
}
|
||||
keepAlivePurgeCallbacks[key]?.add(runnable)
|
||||
}
|
||||
lastInteractionTime = System.currentTimeMillis()
|
||||
connectionNecessarySemaphore.release()
|
||||
}
|
||||
@@ -246,6 +257,7 @@ class IncomingMessageObserver(private val context: Application) {
|
||||
fun removeKeepAliveToken(key: String) {
|
||||
lock.withLock {
|
||||
keepAliveTokens.remove(key)
|
||||
keepAlivePurgeCallbacks.remove(key)
|
||||
lastInteractionTime = System.currentTimeMillis()
|
||||
connectionNecessarySemaphore.release()
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ object WebSocketDrainer {
|
||||
* Also, if it is discovered that it's unlikely that we'll be able to fetch messages (i.e. no network), then the timeout may be reduced.
|
||||
*/
|
||||
@WorkerThread
|
||||
fun blockUntilDrainedAndProcessed(requestedWebsocketDrainTimeoutMs: Long): Boolean {
|
||||
fun blockUntilDrainedAndProcessed(requestedWebsocketDrainTimeoutMs: Long, keepAliveToken: String = KEEP_ALIVE_TOKEN): Boolean {
|
||||
Log.d(TAG, "blockUntilDrainedAndProcessed() requestedWebsocketDrainTimeout: $requestedWebsocketDrainTimeoutMs ms")
|
||||
|
||||
var websocketDrainTimeout = requestedWebsocketDrainTimeoutMs
|
||||
@@ -66,16 +66,13 @@ object WebSocketDrainer {
|
||||
websocketDrainTimeout = NO_NETWORK_WEBSOCKET_TIMEOUT
|
||||
}
|
||||
|
||||
incomingMessageObserver.registerKeepAliveToken(KEEP_ALIVE_TOKEN)
|
||||
|
||||
val wakeLockTag = WAKELOCK_PREFIX + System.currentTimeMillis()
|
||||
val wakeLock = WakeLockUtil.acquire(ApplicationDependencies.getApplication(), PowerManager.PARTIAL_WAKE_LOCK, websocketDrainTimeout + QUEUE_TIMEOUT, wakeLockTag)
|
||||
|
||||
return try {
|
||||
drainAndProcess(websocketDrainTimeout, incomingMessageObserver)
|
||||
drainAndProcess(websocketDrainTimeout, incomingMessageObserver, keepAliveToken)
|
||||
} finally {
|
||||
WakeLockUtil.release(wakeLock, wakeLockTag)
|
||||
incomingMessageObserver.removeKeepAliveToken(KEEP_ALIVE_TOKEN)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,7 +83,7 @@ object WebSocketDrainer {
|
||||
* so that we know the queue has been drained.
|
||||
*/
|
||||
@WorkerThread
|
||||
private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver): Boolean {
|
||||
private fun drainAndProcess(timeout: Long, incomingMessageObserver: IncomingMessageObserver, keepAliveToken: String): Boolean {
|
||||
val stopwatch = Stopwatch("websocket-strategy")
|
||||
|
||||
val jobManager = ApplicationDependencies.getJobManager()
|
||||
@@ -97,7 +94,7 @@ object WebSocketDrainer {
|
||||
queueListener
|
||||
)
|
||||
|
||||
val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout)
|
||||
val successfullyDrained = blockUntilWebsocketDrained(incomingMessageObserver, timeout, keepAliveToken)
|
||||
if (!successfullyDrained) {
|
||||
return false
|
||||
}
|
||||
@@ -119,25 +116,33 @@ object WebSocketDrainer {
|
||||
return true
|
||||
}
|
||||
|
||||
private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long): Boolean {
|
||||
val latch = CountDownLatch(1)
|
||||
incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
|
||||
override fun run() {
|
||||
private fun blockUntilWebsocketDrained(incomingMessageObserver: IncomingMessageObserver, timeoutMs: Long, keepAliveToken: String): Boolean {
|
||||
try {
|
||||
val latch = CountDownLatch(1)
|
||||
var success = false
|
||||
incomingMessageObserver.registerKeepAliveToken(keepAliveToken) {
|
||||
Log.w(TAG, "Keep alive token purged")
|
||||
latch.countDown()
|
||||
incomingMessageObserver.removeDecryptionDrainedListener(this)
|
||||
}
|
||||
})
|
||||
incomingMessageObserver.addDecryptionDrainedListener(object : Runnable {
|
||||
override fun run() {
|
||||
success = true
|
||||
latch.countDown()
|
||||
incomingMessageObserver.removeDecryptionDrainedListener(this)
|
||||
}
|
||||
})
|
||||
|
||||
return try {
|
||||
if (latch.await(timeoutMs, TimeUnit.MILLISECONDS)) {
|
||||
true
|
||||
} else {
|
||||
Log.w(TAG, "Hit timeout while waiting for decryptions to drain!")
|
||||
return try {
|
||||
if (!latch.await(timeoutMs, TimeUnit.MILLISECONDS)) {
|
||||
Log.w(TAG, "Hit timeout while waiting for decryptions to drain!")
|
||||
}
|
||||
success
|
||||
} catch (e: InterruptedException) {
|
||||
Log.w(TAG, "Interrupted!", e)
|
||||
false
|
||||
}
|
||||
} catch (e: InterruptedException) {
|
||||
Log.w(TAG, "Interrupted!", e)
|
||||
false
|
||||
} finally {
|
||||
incomingMessageObserver.removeKeepAliveToken(keepAliveToken)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user