diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManager.java index 257edc2f7..dc87c0c2f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManager.java @@ -192,7 +192,7 @@ public class RedisMessageAvailabilityManager extends RedisClusterPubSubAdapter handleClientDisconnected(final UUID accountIdentifier, final byte deviceId) { + public CompletionStage handleClientDisconnected(final UUID accountIdentifier, final byte deviceId, final MessageAvailabilityListener listener) { if (pubSubConnection == null) { throw new IllegalStateException("WebSocket connection event manager not started"); } - final AtomicReference> unsubscribeFuture = new AtomicReference<>(); + final AtomicReference> unsubscribeFuture = new AtomicReference<>(CompletableFuture.completedFuture(null)); // Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In // particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says: @@ -219,6 +219,11 @@ public class RedisMessageAvailabilityManager extends RedisClusterPubSubAdapter { + if (listener != existingListener) { + // the listener was already replaced, ignore this event completely + return existingListener; + } + unsubscribeFuture.set(CompletableFuture.supplyAsync(() -> pubSubConnection.withPubSubConnection(connection -> connection.async().sunsubscribe(getClientEventChannel(accountIdentifier, deviceId))) .thenRun(Util.NOOP), asyncOperationQueueingExecutor) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java index 6223f7929..40fc63e26 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java @@ -368,7 +368,7 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow } // Stop receiving signals about new messages/conflicting consumers - redisMessageAvailabilityManager.handleClientDisconnected(accountIdentifier, device.getId()); + redisMessageAvailabilityManager.handleClientDisconnected(accountIdentifier, device.getId(), this); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManagerTest.java index 2c0343424..9912f5d9e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/RedisMessageAvailabilityManagerTest.java @@ -23,6 +23,7 @@ import java.util.UUID; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; @@ -108,15 +109,19 @@ class RedisMessageAvailabilityManagerTest { final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false); - localEventManager.handleClientConnected(accountIdentifier, deviceId, new MessageAvailabilityAdapter() { + final AtomicReference firstListener = new AtomicReference<>(); + firstListener.set(new MessageAvailabilityAdapter() { @Override public void handleConflictingMessageConsumer() { synchronized (firstListenerDisplaced) { + localEventManager.handleClientDisconnected(accountIdentifier, deviceId, firstListener.get()); firstListenerDisplaced.set(true); firstListenerDisplaced.notifyAll(); } } - }).toCompletableFuture().join(); + }); + + localEventManager.handleClientConnected(accountIdentifier, deviceId, firstListener.get()).toCompletableFuture().join(); assertFalse(firstListenerDisplaced.get()); assertFalse(secondListenerDisplaced.get()); @@ -139,6 +144,7 @@ class RedisMessageAvailabilityManagerTest { assertTrue(firstListenerDisplaced.get()); assertFalse(secondListenerDisplaced.get()); + assertTrue(displacingManager.isLocallyPresent(accountIdentifier, deviceId)); } @Test @@ -149,14 +155,15 @@ class RedisMessageAvailabilityManagerTest { assertFalse(localEventManager.isLocallyPresent(accountIdentifier, deviceId)); assertFalse(remoteEventManager.isLocallyPresent(accountIdentifier, deviceId)); - localEventManager.handleClientConnected(accountIdentifier, deviceId, new MessageAvailabilityAdapter()) + final MessageAvailabilityAdapter localListener = new MessageAvailabilityAdapter(); + localEventManager.handleClientConnected(accountIdentifier, deviceId, localListener) .toCompletableFuture() .join(); assertTrue(localEventManager.isLocallyPresent(accountIdentifier, deviceId)); assertFalse(remoteEventManager.isLocallyPresent(accountIdentifier, deviceId)); - localEventManager.handleClientDisconnected(accountIdentifier, deviceId) + localEventManager.handleClientDisconnected(accountIdentifier, deviceId, localListener) .toCompletableFuture() .join(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java index 9d4894fd6..79809844d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java @@ -115,7 +115,7 @@ class RedisDynamoDbMessagePublisherTest { new RedisDynamoDbMessagePublisher(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, accountIdentifier, device); verify(redisMessageAvailabilityManager, never()).handleClientConnected(eq(accountIdentifier), eq(deviceId), any()); - verify(redisMessageAvailabilityManager, never()).handleClientDisconnected(eq(accountIdentifier), eq(deviceId)); + verify(redisMessageAvailabilityManager, never()).handleClientDisconnected(eq(accountIdentifier), eq(deviceId), any()); } { @@ -127,7 +127,7 @@ class RedisDynamoDbMessagePublisherTest { JdkFlowAdapter.flowPublisherToFlux(messagePublisher).subscribe(); verify(redisMessageAvailabilityManager).handleClientConnected(eq(accountIdentifier), eq(deviceId), any()); - verify(redisMessageAvailabilityManager, never()).handleClientDisconnected(eq(accountIdentifier), eq(deviceId)); + verify(redisMessageAvailabilityManager, never()).handleClientDisconnected(eq(accountIdentifier), eq(deviceId), any()); } { @@ -140,7 +140,7 @@ class RedisDynamoDbMessagePublisherTest { disposable.dispose(); verify(redisMessageAvailabilityManager).handleClientConnected(eq(accountIdentifier), eq(deviceId), any()); - verify(redisMessageAvailabilityManager).handleClientDisconnected(eq(accountIdentifier), eq(deviceId)); + verify(redisMessageAvailabilityManager).handleClientDisconnected(eq(accountIdentifier), eq(deviceId), any()); } } @@ -327,7 +327,7 @@ class RedisDynamoDbMessagePublisherTest { .verify(); verify(redisMessageAvailabilityManager, timeout(1_000)).handleClientConnected(DESTINATION_SERVICE_IDENTIFIER.uuid(), destinationDevice.getId(), messagePublisher); - verify(redisMessageAvailabilityManager, timeout(1_000)).handleClientDisconnected(DESTINATION_SERVICE_IDENTIFIER.uuid(), destinationDevice.getId()); + verify(redisMessageAvailabilityManager, timeout(1_000)).handleClientDisconnected(DESTINATION_SERVICE_IDENTIFIER.uuid(), destinationDevice.getId(), messagePublisher); } @ParameterizedTest