diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 475339d7d..82166c2b5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -626,7 +626,8 @@ public class WhisperServerService extends Application dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion()); SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, storageServiceExecutor, storageServiceRetryExecutor, config.getSecureStorageServiceConfiguration()); - DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor); + final GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager(); + DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, grpcClientConnectionManager, disconnectionRequestListenerExecutor); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, asyncCdnS3Client, config.getCdnConfiguration().bucket()); MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, messageDeletionAsyncExecutor, clock); @@ -675,8 +676,6 @@ public class WhisperServerService extends Application deviceIds); + void handleDisconnectionRequest(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java index bb6bac3d6..942c58a7c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java @@ -5,21 +5,27 @@ package org.whispersystems.textsecuregcm.auth; +import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; import io.lettuce.core.pubsub.RedisPubSubAdapter; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.concurrent.CompletionStage; -import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; @@ -37,11 +43,11 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil; public class DisconnectionRequestManager extends RedisPubSubAdapter implements Managed { private final FaultTolerantRedisClient pubSubClient; + private final GrpcClientConnectionManager grpcClientConnectionManager; private final Executor listenerEventExecutor; - // We expect just a couple listeners to get added at startup time and not at all at steady-state. There are several - // reasonable ways to model this, but a copy-on-write list gives us good flexibility with minimal performance cost. - private final List listeners = new CopyOnWriteArrayList<>(); + private final Map> listeners = + new ConcurrentHashMap<>(); @Nullable private FaultTolerantPubSubConnection pubSubConnection; @@ -56,10 +62,14 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter { + final List listeners = + existingListeners == null ? new ArrayList<>() : existingListeners; + + listeners.add(listener); + + return listeners; + }); + } + + /** + * Removes a listener for disconnection requests for a specific authenticated device. + * + * @param accountIdentifier TODO + * @param deviceId TODO + * @param listener the listener to remove + */ + public void removeListener(final UUID accountIdentifier, final byte deviceId, final DisconnectionRequestListener listener) { + listeners.computeIfPresent(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), (_, existingListeners) -> { + existingListeners.remove(listener); + + return existingListeners.isEmpty() ? null : existingListeners; + }); + } + + @VisibleForTesting + List getListeners(final UUID accountIdentifier, final byte deviceId) { + return listeners.getOrDefault(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), Collections.emptyList()); } /** @@ -154,12 +192,17 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter listener.handleDisconnectionRequest(accountIdentifier, deviceIds)); - } catch (final Exception e) { - logger.warn("Listener failed to handle disconnection request", e); - } - } + deviceIds.forEach(deviceId -> { + grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(accountIdentifier, deviceId)); + + listeners.getOrDefault(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), Collections.emptyList()) + .forEach(listener -> listenerEventExecutor.execute(() -> { + try { + listener.handleDisconnectionRequest(); + } catch (final Exception e) { + logger.warn("Listener failed to handle disconnection request", e); + } + })); + }); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java index fa35ffcb6..ad12da9fe 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java @@ -10,19 +10,16 @@ import io.netty.channel.local.LocalChannel; import io.netty.util.AttributeKey; import java.net.InetAddress; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Optional; -import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; import org.whispersystems.textsecuregcm.grpc.RequestAttributes; @@ -45,7 +42,7 @@ import org.whispersystems.textsecuregcm.util.ClosableEpoch; * Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may * be called from any application code. */ -public class GrpcClientConnectionManager implements DisconnectionRequestListener { +public class GrpcClientConnectionManager { private final Map remoteChannelsByLocalAddress = new ConcurrentHashMap<>(); private final Map> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>(); @@ -62,7 +59,7 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener static final AttributeKey EPOCH_ATTRIBUTE_KEY = AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch"); - private static OutboundCloseErrorMessage SERVER_CLOSED = + private static final OutboundCloseErrorMessage SERVER_CLOSED = new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"); private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); @@ -268,11 +265,4 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener })); }); } - - @Override - public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection deviceIds) { - deviceIds.stream() - .map(deviceId -> new AuthenticatedDevice(accountIdentifier, deviceId)) - .forEach(this::closeConnection); - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventListener.java index 90078cfc0..6bc6af1f4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventListener.java @@ -22,11 +22,8 @@ public interface WebSocketConnectionEventListener { void handleMessagesPersisted(); /** - * Indicates that the client's presence has been displaced and the listener should close the client's underlying - * network connection. - * - * @param connectedElsewhere if {@code true}, indicates that the client's presence has been displaced by another - * connection from the same client + * Indicates a newer instance of this client has started reading messages and the listener should close this client's + * underlying network connection. */ - void handleConnectionDisplaced(boolean connectedElsewhere); + void handleConflictingMessageReader(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java index b9bc5e523..05fdc1877 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManager.java @@ -17,11 +17,9 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -32,8 +30,6 @@ import java.util.function.Function; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; -import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; @@ -56,10 +52,9 @@ import org.whispersystems.textsecuregcm.util.Util; * servers, but cannot guarantee at-most-one behavior. * * @see WebSocketConnectionEventListener - * @see org.whispersystems.textsecuregcm.storage.MessagesManager#insert(UUID, byte, MessageProtos.Envelope) + * @see org.whispersystems.textsecuregcm.storage.MessagesManager#insert(UUID, Map) */ -public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter implements Managed, - DisconnectionRequestListener { +public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter implements Managed { private final AccountsManager accountsManager; private final PushNotificationManager pushNotificationManager; @@ -145,7 +140,7 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter displacedListener.get().handleConnectionDisplaced(true)); + listenerEventExecutor.execute(() -> displacedListener.get().handleConflictingMessageReader()); } return subscribeFuture.get() @@ -260,14 +255,6 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter deviceIds) { - deviceIds.stream() - .map(deviceId -> listenersByAccountAndDeviceIdentifier.get(new AccountAndDeviceIdentifier(accountIdentifier, deviceId))) - .filter(Objects::nonNull) - .forEach(listener -> listener.handleConnectionDisplaced(false)); - } - @VisibleForTesting void resubscribe(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) { final boolean[] changedSlots = RedisClusterUtil.getChangedSlots(clusterTopologyChangedEvent); @@ -339,7 +326,7 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter listener.handleConnectionDisplaced(true)); + listenerEventExecutor.execute(listener::handleConflictingMessageReader); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index e8f79adaa..9048f56fa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -13,7 +13,9 @@ import java.util.concurrent.ScheduledExecutorService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; @@ -48,6 +50,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final PushNotificationManager pushNotificationManager; private final PushNotificationScheduler pushNotificationScheduler; private final WebSocketConnectionEventManager webSocketConnectionEventManager; + private final DisconnectionRequestManager disconnectionRequestManager; private final ScheduledExecutorService scheduledExecutorService; private final Scheduler messageDeliveryScheduler; private final ClientReleaseManager clientReleaseManager; @@ -58,17 +61,18 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; public AuthenticatedConnectListener( - AccountsManager accountsManager, - ReceiptSender receiptSender, - MessagesManager messagesManager, - MessageMetrics messageMetrics, - PushNotificationManager pushNotificationManager, - PushNotificationScheduler pushNotificationScheduler, - WebSocketConnectionEventManager webSocketConnectionEventManager, - ScheduledExecutorService scheduledExecutorService, - Scheduler messageDeliveryScheduler, - ClientReleaseManager clientReleaseManager, - MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, + final AccountsManager accountsManager, + final ReceiptSender receiptSender, + final MessagesManager messagesManager, + final MessageMetrics messageMetrics, + final PushNotificationManager pushNotificationManager, + final PushNotificationScheduler pushNotificationScheduler, + final WebSocketConnectionEventManager webSocketConnectionEventManager, + final DisconnectionRequestManager disconnectionRequestManager, + final ScheduledExecutorService scheduledExecutorService, + final Scheduler messageDeliveryScheduler, + final ClientReleaseManager clientReleaseManager, + final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, final ExperimentEnrollmentManager experimentEnrollmentManager) { this.accountsManager = accountsManager; @@ -78,6 +82,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this.pushNotificationManager = pushNotificationManager; this.pushNotificationScheduler = pushNotificationScheduler; this.webSocketConnectionEventManager = webSocketConnectionEventManager; + this.disconnectionRequestManager = disconnectionRequestManager; this.scheduledExecutorService = scheduledExecutorService; this.messageDeliveryScheduler = messageDeliveryScheduler; this.clientReleaseManager = clientReleaseManager; @@ -104,7 +109,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); final Optional maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.accountIdentifier()); - final Optional maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId()));; + final Optional maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId())); if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) { log.warn("{}:{} not found when opening authenticated WebSocket", auth.accountIdentifier(), auth.deviceId()); @@ -127,7 +132,15 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { messageDeliveryLoopMonitor, experimentEnrollmentManager); - context.addWebsocketClosedListener((closingContext, statusCode, reason) -> { + disconnectionRequestManager.addListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI), + maybeAuthenticatedDevice.get().getId(), + connection); + + context.addWebsocketClosedListener((_, _, _) -> { + disconnectionRequestManager.removeListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI), + maybeAuthenticatedDevice.get().getId(), + connection); + // We begin the shutdown process by removing this client's "presence," which means it will again begin to // receive push notifications for inbound messages. We should do this first because, at this point, the // connection has already closed and attempts to actually deliver a message via the connection will not succeed. diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 6d32574f1..a9d8148dc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -37,6 +37,7 @@ import org.eclipse.jetty.util.StaticException; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; @@ -65,7 +66,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; -public class WebSocketConnection implements WebSocketConnectionEventListener { +public class WebSocketConnection implements WebSocketConnectionEventListener, DisconnectionRequestListener { private static final DistributionSummary messageTime = Metrics.summary( name(MessageController.class, "messageDeliveryDuration")); @@ -506,24 +507,23 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { } @Override - public void handleConnectionDisplaced(final boolean connectedElsewhere) { + public void handleConflictingMessageReader() { + closeConnection(4409, "Connected elsewhere"); + } + + @Override + public void handleDisconnectionRequest() { + closeConnection(4401, "Reauthentication required"); + } + + private void closeConnection(final int code, final String message) { final Tags tags = Tags.of( UserAgentTagUtil.getPlatformTag(client.getUserAgent()), - Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere))); + // TODO We should probably just use the status code directly + Tag.of("connectedElsewhere", String.valueOf(code == 4409))); Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment(); - final int code; - final String message; - - if (connectedElsewhere) { - code = 4409; - message = "Connected elsewhere"; - } else { - code = 4401; - message = "Reauthentication required"; - } - client.close(code, message); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index da5229149..bc62b8d2a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controll import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.controllers.SecureValueRecoveryBController; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.push.APNSender; @@ -254,7 +255,8 @@ record CommandDependencies( () -> dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion()); SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration()); - DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor); + GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager(); + DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, grpcClientConnectionManager, disconnectionRequestListenerExecutor); MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC()); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, asyncCdnS3Client, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java index ccb579530..acb341979 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java @@ -6,20 +6,21 @@ package org.whispersystems.textsecuregcm.auth; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.util.Collection; -import java.util.HashSet; import java.util.List; -import java.util.Set; import java.util.UUID; -import java.util.concurrent.CountDownLatch; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.storage.Account; @@ -28,42 +29,19 @@ import org.whispersystems.textsecuregcm.storage.Device; @Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class DisconnectionRequestManagerTest { + private GrpcClientConnectionManager grpcClientConnectionManager; private DisconnectionRequestManager disconnectionRequestManager; @RegisterExtension static final RedisServerExtension REDIS_EXTENSION = RedisServerExtension.builder().build(); - private static class DisconnectionRequestTestListener implements DisconnectionRequestListener { - - private final CountDownLatch requestLatch = new CountDownLatch(1); - - private UUID accountIdentifier; - private Collection deviceIds; - - @Override - public void handleDisconnectionRequest(final UUID accountIdentifier, final Collection deviceIds) { - this.accountIdentifier = accountIdentifier; - this.deviceIds = deviceIds; - - requestLatch.countDown(); - } - - public UUID getAccountIdentifier() { - return accountIdentifier; - } - - public Collection getDeviceIds() { - return deviceIds; - } - - public void waitForRequest() throws InterruptedException { - requestLatch.await(); - } - } - @BeforeEach void setUp() { - disconnectionRequestManager = new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(), Runnable::run); + grpcClientConnectionManager = mock(GrpcClientConnectionManager.class); + + disconnectionRequestManager = + new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(), grpcClientConnectionManager, Runnable::run); + disconnectionRequestManager.start(); } @@ -73,43 +51,98 @@ class DisconnectionRequestManagerTest { } @Test - void requestDisconnection() throws InterruptedException { + void addRemoveListener() { final UUID accountIdentifier = UUID.randomUUID(); - final List deviceIds = List.of(Device.PRIMARY_ID, (byte) (Device.PRIMARY_ID + 1)); + final byte deviceId = Device.PRIMARY_ID; - final DisconnectionRequestTestListener listener = new DisconnectionRequestTestListener(); + final DisconnectionRequestListener firstListener = mock(DisconnectionRequestListener.class); + final DisconnectionRequestListener secondListener = mock(DisconnectionRequestListener.class); - disconnectionRequestManager.addListener(listener); - disconnectionRequestManager.requestDisconnection(accountIdentifier, deviceIds).toCompletableFuture().join(); + assertTrue(disconnectionRequestManager.getListeners(accountIdentifier, deviceId).isEmpty()); - listener.waitForRequest(); + disconnectionRequestManager.addListener(accountIdentifier, deviceId, firstListener); - assertEquals(accountIdentifier, listener.getAccountIdentifier()); - assertEquals(deviceIds, listener.getDeviceIds()); + assertEquals(List.of(firstListener), disconnectionRequestManager.getListeners(accountIdentifier, deviceId)); + + disconnectionRequestManager.addListener(accountIdentifier, deviceId, secondListener); + + assertEquals(List.of(firstListener, secondListener), + disconnectionRequestManager.getListeners(accountIdentifier, deviceId)); + + disconnectionRequestManager.removeListener(accountIdentifier, deviceId, mock(DisconnectionRequestListener.class)); + + assertEquals(List.of(firstListener, secondListener), + disconnectionRequestManager.getListeners(accountIdentifier, deviceId)); + + disconnectionRequestManager.removeListener(accountIdentifier, deviceId, firstListener); + + assertEquals(List.of(secondListener), disconnectionRequestManager.getListeners(accountIdentifier, deviceId)); } @Test - void requestDisconnectionAllDevices() throws InterruptedException { + void requestDisconnection() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte primaryDeviceId = Device.PRIMARY_ID; + final byte linkedDeviceId = primaryDeviceId + 1; + + final UUID otherAccountIdentifier = UUID.randomUUID(); + final byte otherDeviceId = linkedDeviceId + 1; + + final List deviceIds = List.of(primaryDeviceId, linkedDeviceId); + + final DisconnectionRequestListener primaryDeviceListener = mock(DisconnectionRequestListener.class); + final DisconnectionRequestListener linkedDeviceListener = mock(DisconnectionRequestListener.class); + + disconnectionRequestManager.addListener(accountIdentifier, primaryDeviceId, primaryDeviceListener); + disconnectionRequestManager.addListener(accountIdentifier, linkedDeviceId, linkedDeviceListener); + + disconnectionRequestManager.requestDisconnection(accountIdentifier, deviceIds).toCompletableFuture().join(); + + verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest(); + verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest(); + verify(grpcClientConnectionManager, timeout(1_000)) + .closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId)); + + verify(grpcClientConnectionManager, timeout(1_000)) + .closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId)); + + disconnectionRequestManager.requestDisconnection(otherAccountIdentifier, List.of(otherDeviceId)); + + verify(grpcClientConnectionManager, timeout(1_000)) + .closeConnection(new AuthenticatedDevice(otherAccountIdentifier, otherDeviceId)); + } + + @Test + void requestDisconnectionAllDevices() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte primaryDeviceId = Device.PRIMARY_ID; + final byte linkedDeviceId = primaryDeviceId + 1; + final Device primaryDevice = mock(Device.class); - when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + when(primaryDevice.getId()).thenReturn(primaryDeviceId); final Device linkedDevice = mock(Device.class); - when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1)); - - final UUID accountIdentifier = UUID.randomUUID(); + when(linkedDevice.getId()).thenReturn(linkedDeviceId); final Account account = mock(Account.class); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice)); - final DisconnectionRequestTestListener listener = new DisconnectionRequestTestListener(); + final DisconnectionRequestListener primaryDeviceListener = mock(DisconnectionRequestListener.class); + final DisconnectionRequestListener linkedDeviceListener = mock(DisconnectionRequestListener.class); + + disconnectionRequestManager.addListener(accountIdentifier, primaryDeviceId, primaryDeviceListener); + disconnectionRequestManager.addListener(accountIdentifier, linkedDeviceId, linkedDeviceListener); - disconnectionRequestManager.addListener(listener); disconnectionRequestManager.requestDisconnection(account).toCompletableFuture().join(); - listener.waitForRequest(); + verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest(); + verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest(); - assertEquals(accountIdentifier, listener.getAccountIdentifier()); - assertEquals(List.of(Device.PRIMARY_ID, (byte) (Device.PRIMARY_ID + 1)), listener.getDeviceIds()); + verify(grpcClientConnectionManager, timeout(1_000)) + .closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId)); + + verify(grpcClientConnectionManager, timeout(1_000)) + .closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java index b669fb126..c71a056aa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java @@ -67,7 +67,7 @@ class WebSocketConnectionEventManagerTest { } @Override - public void handleConnectionDisplaced(final boolean connectedElsewhere) { + public void handleConflictingMessageReader() { } } @@ -116,15 +116,12 @@ class WebSocketConnectionEventManagerTest { final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false); final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false); - final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false); localEventManager.handleClientConnected(accountIdentifier, deviceId, new WebSocketConnectionEventAdapter() { @Override - public void handleConnectionDisplaced(final boolean connectedElsewhere) { + public void handleConflictingMessageReader() { synchronized (firstListenerDisplaced) { firstListenerDisplaced.set(true); - firstListenerConnectedElsewhere.set(connectedElsewhere); - firstListenerDisplaced.notifyAll(); } } @@ -138,7 +135,7 @@ class WebSocketConnectionEventManagerTest { displacingManager.handleClientConnected(accountIdentifier, deviceId, new WebSocketConnectionEventAdapter() { @Override - public void handleConnectionDisplaced(final boolean connectedElsewhere) { + public void handleConflictingMessageReader() { secondListenerDisplaced.set(true); } }).toCompletableFuture().join(); @@ -151,8 +148,6 @@ class WebSocketConnectionEventManagerTest { assertTrue(firstListenerDisplaced.get()); assertFalse(secondListenerDisplaced.get()); - - assertTrue(firstListenerConnectedElsewhere.get()); } @Test @@ -178,56 +173,6 @@ class WebSocketConnectionEventManagerTest { assertFalse(remoteEventManager.isLocallyPresent(accountIdentifier, deviceId)); } - @Test - void handleDisconnectionRequest() throws InterruptedException { - final UUID accountIdentifier = UUID.randomUUID(); - final byte firstDeviceId = Device.PRIMARY_ID; - final byte secondDeviceId = firstDeviceId + 1; - - final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false); - final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false); - - final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false); - - localEventManager.handleClientConnected(accountIdentifier, firstDeviceId, new WebSocketConnectionEventAdapter() { - @Override - public void handleConnectionDisplaced(final boolean connectedElsewhere) { - synchronized (firstListenerDisplaced) { - firstListenerDisplaced.set(true); - firstListenerConnectedElsewhere.set(connectedElsewhere); - - firstListenerDisplaced.notifyAll(); - } - } - }).toCompletableFuture().join(); - - localEventManager.handleClientConnected(accountIdentifier, secondDeviceId, new WebSocketConnectionEventAdapter() { - @Override - public void handleConnectionDisplaced(final boolean connectedElsewhere) { - synchronized (secondListenerDisplaced) { - secondListenerDisplaced.set(true); - secondListenerDisplaced.notifyAll(); - } - } - }).toCompletableFuture().join(); - - assertFalse(firstListenerDisplaced.get()); - assertFalse(secondListenerDisplaced.get()); - - localEventManager.handleDisconnectionRequest(accountIdentifier, List.of(firstDeviceId)); - - synchronized (firstListenerDisplaced) { - while (!firstListenerDisplaced.get()) { - firstListenerDisplaced.wait(); - } - } - - assertTrue(firstListenerDisplaced.get()); - assertFalse(secondListenerDisplaced.get()); - - assertFalse(firstListenerConnectedElsewhere.get()); - } - @Test void resubscribe() { @SuppressWarnings("unchecked") final RedisClusterPubSubCommands pubSubCommands = diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index d83919f11..9f7284b3c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -170,7 +170,7 @@ class MessagePersisterIntegrationTest { } @Override - public void handleConnectionDisplaced(final boolean connectedElsewhere) { + public void handleConflictingMessageReader() { } }); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 32ff55b51..93677cb70 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -55,6 +55,7 @@ import org.junit.jupiter.api.Test; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -124,7 +125,7 @@ class WebSocketConnectionTest { new WebSocketAccountAuthenticator(accountAuthenticator); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), - mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor, + mock(WebSocketConnectionEventManager.class), mock(DisconnectionRequestManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), mock(ExperimentEnrollmentManager.class)); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);