Introduce and evaluate a client presence manager based on sharded pub/sub

This commit is contained in:
Jon Chambers
2024-11-05 15:51:29 -05:00
committed by GitHub
parent 60cdcf5f0c
commit 8c984cbf42
35 changed files with 1339 additions and 56 deletions

View File

@@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
@@ -61,6 +60,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -97,6 +97,7 @@ class LinkedDeviceRefreshRequirementProviderTest {
private AccountsManager accountsManager;
private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private LinkedDeviceRefreshRequirementProvider provider;
@@ -104,11 +105,12 @@ class LinkedDeviceRefreshRequirementProviderTest {
void setup() {
accountsManager = mock(AccountsManager.class);
clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
provider = new LinkedDeviceRefreshRequirementProvider(accountsManager);
final WebsocketRefreshRequestEventListener listener =
new WebsocketRefreshRequestEventListener(clientPresenceManager, provider);
new WebsocketRefreshRequestEventListener(clientPresenceManager, pubSubClientEventManager, provider);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
@@ -146,6 +148,10 @@ class LinkedDeviceRefreshRequirementProviderTest {
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3);
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of((byte) 3));
}
@ParameterizedTest
@@ -173,8 +179,10 @@ class LinkedDeviceRefreshRequirementProviderTest {
assertEquals(200, response.getStatus());
initialDeviceIds.forEach(deviceId ->
verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId));
initialDeviceIds.forEach(deviceId -> {
verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId);
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(deviceId));
});
verifyNoMoreInteractions(clientPresenceManager);
}

View File

@@ -28,6 +28,7 @@ import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import javax.servlet.DispatcherType;
@@ -47,6 +48,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -74,6 +76,7 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class);
private static final PubSubClientEventManager PUBSUB_CLIENT_PRESENCE = mock(PubSubClientEventManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
@@ -122,9 +125,9 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE));
environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE, PUBSUB_CLIENT_PRESENCE));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
@@ -215,6 +218,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
verify(PUBSUB_CLIENT_PRESENCE, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE);
}
@Test
@@ -231,6 +238,10 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
verify(PUBSUB_CLIENT_PRESENCE, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(PUBSUB_CLIENT_PRESENCE);
}
@ParameterizedTest

View File

@@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@@ -47,6 +48,7 @@ class RegistrationLockVerificationManagerTest {
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class);
private final ExternalServiceCredentialsGenerator svr2CredentialsGenerator = mock(
ExternalServiceCredentialsGenerator.class);
private final ExternalServiceCredentialsGenerator svr3CredentialsGenerator = mock(
@@ -56,7 +58,7 @@ class RegistrationLockVerificationManagerTest {
private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, clientPresenceManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
accountsManager, clientPresenceManager, pubSubClientEventManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);
private final RateLimiter pinLimiter = mock(RateLimiter.class);
@@ -108,6 +110,7 @@ class RegistrationLockVerificationManagerTest {
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
}
verify(clientPresenceManager).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(Device.PRIMARY_ID));
try {
verify(pushNotificationManager).sendAttemptLoginNotification(any(), eq("failedRegistrationLock"));
} catch (NotPushRegisteredException npre) {}
@@ -131,6 +134,7 @@ class RegistrationLockVerificationManagerTest {
} catch (NotPushRegisteredException npre) {}
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
});
}
};
@@ -169,6 +173,7 @@ class RegistrationLockVerificationManagerTest {
verify(account, never()).lockAuthTokenHash();
verify(registrationRecoveryPasswordsManager, never()).removeForNumber(account.getNumber());
verify(clientPresenceManager, never()).disconnectAllPresences(account.getUuid(), List.of(Device.PRIMARY_ID));
verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
}
static Stream<Arguments> testSuccess() {

View File

@@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
@@ -111,6 +112,7 @@ class DeviceControllerTest {
private static final Account maxedAccount = mock(Account.class);
private static final Device primaryDevice = mock(Device.class);
private static final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private static final PubSubClientEventManager pubSubClientEventManager = mock(PubSubClientEventManager.class);
private static final Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final TestClock testClock = TestClock.now();
@@ -131,7 +133,8 @@ class DeviceControllerTest {
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager))
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager,
pubSubClientEventManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(deviceController)
.build();

View File

@@ -21,6 +21,7 @@ import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -39,6 +40,7 @@ class MessageSenderTest {
private MessageProtos.Envelope message;
private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private MessagesManager messagesManager;
private PushNotificationManager pushNotificationManager;
private MessageSender messageSender;
@@ -54,9 +56,14 @@ class MessageSenderTest {
message = generateRandomMessage();
clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
messagesManager = mock(MessagesManager.class);
pushNotificationManager = mock(PushNotificationManager.class);
messageSender = new MessageSender(clientPresenceManager, messagesManager, pushNotificationManager);
when(pubSubClientEventManager.handleNewMessageAvailable(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(true));
messageSender = new MessageSender(clientPresenceManager, pubSubClientEventManager, messagesManager, pushNotificationManager);
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(device.getId()).thenReturn(DEVICE_ID);

View File

@@ -0,0 +1,337 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands;
import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.IntStream;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class PubSubClientEventManagerTest {
private PubSubClientEventManager localPresenceManager;
private PubSubClientEventManager remotePresenceManager;
private static ExecutorService clientEventExecutor;
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static class ClientEventAdapter implements ClientEventListener {
@Override
public void handleNewMessageAvailable() {
}
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
}
}
@BeforeAll
static void setUpBeforeAll() {
clientEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
}
@BeforeEach
void setUp() {
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
localPresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager);
remotePresenceManager = new PubSubClientEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor, experimentEnrollmentManager);
localPresenceManager.start();
remotePresenceManager.start();
}
@AfterEach
void tearDown() {
localPresenceManager.stop();
remotePresenceManager.stop();
}
@AfterAll
static void tearDownAfterAll() {
clientEventExecutor.shutdown();
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void handleClientConnected(final boolean displaceRemotely) throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false);
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
synchronized (firstListenerDisplaced) {
firstListenerDisplaced.set(true);
firstListenerConnectedElsewhere.set(connectedElsewhere);
firstListenerDisplaced.notifyAll();
}
}
}).toCompletableFuture().join();
assertFalse(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
final PubSubClientEventManager displacingManager =
displaceRemotely ? remotePresenceManager : localPresenceManager;
displacingManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
secondListenerDisplaced.set(true);
}
}).toCompletableFuture().join();
synchronized (firstListenerDisplaced) {
while (!firstListenerDisplaced.get()) {
firstListenerDisplaced.wait();
}
}
assertTrue(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
assertTrue(firstListenerConnectedElsewhere.get());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void handleNewMessageAvailable(final boolean messageAvailableRemotely) throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final AtomicBoolean messageReceived = new AtomicBoolean(false);
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter() {
@Override
public void handleNewMessageAvailable() {
synchronized (messageReceived) {
messageReceived.set(true);
messageReceived.notifyAll();
}
}
}).toCompletableFuture().join();
final PubSubClientEventManager messagePresenceManager =
messageAvailableRemotely ? remotePresenceManager : localPresenceManager;
assertTrue(messagePresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
synchronized (messageReceived) {
while (!messageReceived.get()) {
messageReceived.wait();
}
}
assertTrue(messageReceived.get());
}
@Test
void handleClientDisconnected() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final UUID connectionId =
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture().join();
assertTrue(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId).toCompletableFuture().join();
assertFalse(localPresenceManager.handleNewMessageAvailable(accountIdentifier, deviceId).toCompletableFuture().join());
}
@Test
void isLocallyPresent() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
final UUID connectionId =
localPresenceManager.handleClientConnected(accountIdentifier, deviceId, new ClientEventAdapter())
.toCompletableFuture()
.join();
assertTrue(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
localPresenceManager.handleClientDisconnected(accountIdentifier, deviceId, connectionId)
.toCompletableFuture()
.join();
assertFalse(localPresenceManager.isLocallyPresent(accountIdentifier, deviceId));
assertFalse(remotePresenceManager.isLocallyPresent(accountIdentifier, deviceId));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void requestDisconnection(final boolean requestDisconnectionRemotely) throws InterruptedException {
final UUID accountIdentifier = UUID.randomUUID();
final byte firstDeviceId = Device.PRIMARY_ID;
final byte secondDeviceId = firstDeviceId + 1;
final AtomicBoolean firstListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean secondListenerDisplaced = new AtomicBoolean(false);
final AtomicBoolean firstListenerConnectedElsewhere = new AtomicBoolean(false);
localPresenceManager.handleClientConnected(accountIdentifier, firstDeviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
synchronized (firstListenerDisplaced) {
firstListenerDisplaced.set(true);
firstListenerConnectedElsewhere.set(connectedElsewhere);
firstListenerDisplaced.notifyAll();
}
}
}).toCompletableFuture().join();
localPresenceManager.handleClientConnected(accountIdentifier, secondDeviceId, new ClientEventAdapter() {
@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
synchronized (secondListenerDisplaced) {
secondListenerDisplaced.set(true);
secondListenerDisplaced.notifyAll();
}
}
}).toCompletableFuture().join();
assertFalse(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
final PubSubClientEventManager displacingManager =
requestDisconnectionRemotely ? remotePresenceManager : localPresenceManager;
displacingManager.requestDisconnection(accountIdentifier, List.of(firstDeviceId)).toCompletableFuture().join();
synchronized (firstListenerDisplaced) {
while (!firstListenerDisplaced.get()) {
firstListenerDisplaced.wait();
}
}
assertTrue(firstListenerDisplaced.get());
assertFalse(secondListenerDisplaced.get());
assertFalse(firstListenerConnectedElsewhere.get());
}
@Test
void resubscribe() {
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
@SuppressWarnings("unchecked") final RedisClusterPubSubCommands<byte[], byte[]> pubSubCommands =
mock(RedisClusterPubSubCommands.class);
@SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands<byte[], byte[]> pubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);
when(pubSubAsyncCommands.ssubscribe(any())).thenReturn(MockRedisFuture.completedFuture(null));
final FaultTolerantRedisClusterClient clusterClient = RedisClusterHelper.builder()
.binaryPubSubCommands(pubSubCommands)
.binaryPubSubAsyncCommands(pubSubAsyncCommands)
.build();
final PubSubClientEventManager presenceManager =
new PubSubClientEventManager(clusterClient, Runnable::run, experimentEnrollmentManager);
presenceManager.start();
final UUID firstAccountIdentifier = UUID.randomUUID();
final byte firstDeviceId = Device.PRIMARY_ID;
final int firstSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId));
final UUID secondAccountIdentifier;
final byte secondDeviceId = firstDeviceId + 1;
// Make sure that the two subscriptions wind up in different slots
{
UUID candidateIdentifier;
do {
candidateIdentifier = UUID.randomUUID();
} while (SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(candidateIdentifier, secondDeviceId)) == firstSlot);
secondAccountIdentifier = candidateIdentifier;
}
presenceManager.handleClientConnected(firstAccountIdentifier, firstDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
presenceManager.handleClientConnected(secondAccountIdentifier, secondDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
final int secondSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId));
final String firstNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstBeforeNode = mock(RedisClusterNode.class);
when(firstBeforeNode.getNodeId()).thenReturn(firstNodeId);
when(firstBeforeNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT).boxed().toList());
final RedisClusterNode firstAfterNode = mock(RedisClusterNode.class);
when(firstAfterNode.getNodeId()).thenReturn(firstNodeId);
when(firstAfterNode.getSlots()).thenReturn(IntStream.range(0, SlotHash.SLOT_COUNT)
.filter(slot -> slot != secondSlot)
.boxed()
.toList());
final RedisClusterNode secondAfterNode = mock(RedisClusterNode.class);
when(secondAfterNode.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(secondAfterNode.getSlots()).thenReturn(List.of(secondSlot));
presenceManager.resubscribe(new ClusterTopologyChangedEvent(
List.of(firstBeforeNode),
List.of(firstAfterNode, secondAfterNode)));
verify(pubSubCommands).ssubscribe(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId));
verify(pubSubCommands, never()).ssubscribe(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId));
}
}

View File

@@ -31,6 +31,7 @@ import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
@@ -46,7 +47,7 @@ class FaultTolerantPubSubClusterConnectionTest {
private TestPublisher<Event> eventPublisher;
private Runnable resubscribe;
private Consumer<ClusterTopologyChangedEvent> resubscribe;
private AtomicInteger resubscribeCounter;
private CountDownLatch resubscribeFailure;
@@ -93,7 +94,7 @@ class FaultTolerantPubSubClusterConnectionTest {
resubscribeCounter = new AtomicInteger();
resubscribe = () -> {
resubscribe = event -> {
try {
resubscribeCounter.incrementAndGet();
pubSubConnection.sync().nodes((ignored) -> true);
@@ -137,7 +138,7 @@ class FaultTolerantPubSubClusterConnectionTest {
void testFilterClusterTopologyChangeEvents() throws InterruptedException {
final CountDownLatch topologyEventLatch = new CountDownLatch(1);
faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(topologyEventLatch::countDown);
faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(event -> topologyEventLatch.countDown());
final RedisClusterNode nodeFromDifferentCluster = mock(RedisClusterNode.class);

View File

@@ -44,6 +44,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -152,6 +153,7 @@ public class AccountCreationDeletionIntegrationTest {
secureStorageClient,
svr2Client,
mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
accountLockExecutor,

View File

@@ -37,6 +37,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -67,6 +68,7 @@ class AccountsManagerChangeNumberIntegrationTest {
private KeysManager keysManager;
private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private ExecutorService accountLockExecutor;
private ExecutorService clientPresenceExecutor;
@@ -119,6 +121,7 @@ class AccountsManagerChangeNumberIntegrationTest {
when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null));
clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName());
@@ -147,6 +150,7 @@ class AccountsManagerChangeNumberIntegrationTest {
secureStorageClient,
svr2Client,
clientPresenceManager,
pubSubClientEventManager,
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
accountLockExecutor,
@@ -281,7 +285,8 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow());
verify(clientPresenceManager).disconnectPresence(existingAccountUuid, Device.PRIMARY_ID);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingAccountUuid);
verify(pubSubClientEventManager).requestDisconnection(existingAccountUuid);
assertEquals(Optional.of(existingAccountUuid), accountsManager.findRecentlyDeletedAccountIdentifier(originalNumber));
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondNumber));

View File

@@ -49,6 +49,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@@ -134,6 +135,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class),
mock(Executor.class),

View File

@@ -15,6 +15,7 @@ import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -63,6 +64,7 @@ public class AccountsManagerDeviceTransferIntegrationTest {
mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class),
mock(ExecutorService.class),

View File

@@ -80,6 +80,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -118,6 +119,7 @@ class AccountsManagerTest {
private MessagesManager messagesManager;
private ProfilesManager profilesManager;
private ClientPresenceManager clientPresenceManager;
private PubSubClientEventManager pubSubClientEventManager;
private ClientPublicKeysManager clientPublicKeysManager;
private Map<String, UUID> phoneNumberIdentifiersByE164;
@@ -153,6 +155,7 @@ class AccountsManagerTest {
messagesManager = mock(MessagesManager.class);
profilesManager = mock(ProfilesManager.class);
clientPresenceManager = mock(ClientPresenceManager.class);
pubSubClientEventManager = mock(PubSubClientEventManager.class);
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
dynamicConfiguration = mock(DynamicConfiguration.class);
@@ -259,6 +262,7 @@ class AccountsManagerTest {
storageClient,
svr2Client,
clientPresenceManager,
pubSubClientEventManager,
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
mock(Executor.class),
@@ -799,6 +803,7 @@ class AccountsManagerTest {
verify(keysManager).buildWriteItemsForRemovedDevice(account.getUuid(), account.getPhoneNumberIdentifier(), linkedDevice.getId());
verify(clientPublicKeysManager).buildTransactWriteItemForDeletion(account.getUuid(), linkedDevice.getId());
verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId());
verify(pubSubClientEventManager).requestDisconnection(account.getUuid(), List.of(linkedDevice.getId()));
}
@Test
@@ -817,6 +822,7 @@ class AccountsManagerTest {
verify(messagesManager, never()).clear(any(), anyByte());
verify(keysManager, never()).deleteSingleUsePreKeys(any(), anyByte());
verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte());
verify(pubSubClientEventManager, never()).requestDisconnection(any(), any());
}
@Test
@@ -886,6 +892,7 @@ class AccountsManagerTest {
verify(messagesManager, times(2)).clear(existingUuid);
verify(profilesManager, times(2)).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
verify(pubSubClientEventManager).requestDisconnection(existingUuid);
}
@Test

View File

@@ -36,6 +36,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -146,6 +147,7 @@ class AccountsManagerUsernameIntegrationTest {
mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class),
Executors.newSingleThreadExecutor(),

View File

@@ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -152,6 +153,7 @@ public class AddRemoveDeviceIntegrationTest {
secureStorageClient,
svr2Client,
mock(ClientPresenceManager.class),
mock(PubSubClientEventManager.class),
registrationRecoveryPasswordsManager,
clientPublicKeysManager,
accountLockExecutor,

View File

@@ -14,8 +14,12 @@ import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands;
import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands;
import java.util.function.Consumer;
import java.util.function.Function;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
public class RedisClusterHelper {
@@ -30,7 +34,12 @@ public class RedisClusterHelper {
final RedisAdvancedClusterAsyncCommands<String, String> stringAsyncCommands,
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands,
final RedisAdvancedClusterAsyncCommands<byte[], byte[]> binaryAsyncCommands,
final RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands) {
final RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands,
final RedisClusterPubSubCommands<String, String> stringPubSubCommands,
final RedisClusterPubSubAsyncCommands<String, String> stringAsyncPubSubCommands,
final RedisClusterPubSubCommands<byte[], byte[]> binaryPubSubCommands,
final RedisClusterPubSubAsyncCommands<byte[], byte[]> binaryAsyncPubSubCommands) {
final FaultTolerantRedisClusterClient cluster = mock(FaultTolerantRedisClusterClient.class);
final StatefulRedisClusterConnection<String, String> stringConnection = mock(StatefulRedisClusterConnection.class);
final StatefulRedisClusterConnection<byte[], byte[]> binaryConnection = mock(StatefulRedisClusterConnection.class);
@@ -59,6 +68,45 @@ public class RedisClusterHelper {
return null;
}).when(cluster).useBinaryCluster(any(Consumer.class));
final StatefulRedisClusterPubSubConnection<String, String> stringPubSubConnection =
mock(StatefulRedisClusterPubSubConnection.class);
final StatefulRedisClusterPubSubConnection<byte[], byte[]> binaryPubSubConnection =
mock(StatefulRedisClusterPubSubConnection.class);
final FaultTolerantPubSubClusterConnection<String, String> faultTolerantPubSubClusterConnection =
mock(FaultTolerantPubSubClusterConnection.class);
final FaultTolerantPubSubClusterConnection<byte[], byte[]> faultTolerantBinaryPubSubClusterConnection =
mock(FaultTolerantPubSubClusterConnection.class);
when(stringPubSubConnection.sync()).thenReturn(stringPubSubCommands);
when(stringPubSubConnection.async()).thenReturn(stringAsyncPubSubCommands);
when(binaryPubSubConnection.sync()).thenReturn(binaryPubSubCommands);
when(binaryPubSubConnection.async()).thenReturn(binaryAsyncPubSubCommands);
when(cluster.createPubSubConnection()).thenReturn(faultTolerantPubSubClusterConnection);
when(cluster.createBinaryPubSubConnection()).thenReturn(faultTolerantBinaryPubSubClusterConnection);
when(faultTolerantPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringPubSubConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringPubSubConnection);
return null;
}).when(faultTolerantPubSubClusterConnection).usePubSubConnection(any(Consumer.class));
when(faultTolerantBinaryPubSubClusterConnection.withPubSubConnection(any(Function.class))).thenAnswer(
invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryPubSubConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryPubSubConnection);
return null;
}).when(faultTolerantBinaryPubSubClusterConnection).usePubSubConnection(any(Consumer.class));
return cluster;
}
@@ -77,6 +125,18 @@ public class RedisClusterHelper {
private RedisAdvancedClusterReactiveCommands<byte[], byte[]> binaryReactiveCommands =
mock(RedisAdvancedClusterReactiveCommands.class);
private RedisClusterPubSubCommands<String, String> stringPubSubCommands =
mock(RedisClusterPubSubCommands.class);
private RedisClusterPubSubCommands<byte[], byte[]> binaryPubSubCommands =
mock(RedisClusterPubSubCommands.class);
private RedisClusterPubSubAsyncCommands<String, String> stringPubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);
private RedisClusterPubSubAsyncCommands<byte[], byte[]> binaryPubSubAsyncCommands =
mock(RedisClusterPubSubAsyncCommands.class);
private Builder() {
}
@@ -107,9 +167,33 @@ public class RedisClusterHelper {
return this;
}
public Builder stringPubSubCommands(final RedisClusterPubSubCommands<String, String> stringPubSubCommands) {
this.stringPubSubCommands = stringPubSubCommands;
return this;
}
public Builder binaryPubSubCommands(final RedisClusterPubSubCommands<byte[], byte[]> binaryPubSubCommands) {
this.binaryPubSubCommands = binaryPubSubCommands;
return this;
}
public Builder stringPubSubAsyncCommands(
final RedisClusterPubSubAsyncCommands<String, String> stringPubSubAsyncCommands) {
this.stringPubSubAsyncCommands = stringPubSubAsyncCommands;
return this;
}
public Builder binaryPubSubAsyncCommands(
final RedisClusterPubSubAsyncCommands<byte[], byte[]> binaryPubSubAsyncCommands) {
this.binaryPubSubAsyncCommands = binaryPubSubAsyncCommands;
return this;
}
public FaultTolerantRedisClusterClient build() {
return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands, binaryAsyncCommands,
binaryReactiveCommands);
return RedisClusterHelper.buildMockRedisCluster(stringCommands, stringAsyncCommands, binaryCommands,
binaryAsyncCommands,
binaryReactiveCommands, stringPubSubCommands, stringPubSubAsyncCommands, binaryPubSubCommands,
binaryPubSubAsyncCommands);
}
}

View File

@@ -5,17 +5,199 @@
package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
class RedisClusterUtilTest {
@Test
void testGetMinimalHashTag() {
for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) {
assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot)));
}
@Test
void testGetMinimalHashTag() {
for (int slot = 0; slot < SlotHash.SLOT_COUNT; slot++) {
assertEquals(slot, SlotHash.getSlot(RedisClusterUtil.getMinimalHashTag(slot)));
}
}
@ParameterizedTest
@MethodSource
void getChangedSlots(final ClusterTopologyChangedEvent event, final boolean[] expectedSlotsChanged) {
assertArrayEquals(expectedSlotsChanged, RedisClusterUtil.getChangedSlots(event));
}
private static List<Arguments> getChangedSlots() {
final List<Arguments> arguments = new ArrayList<>();
// Slot moved from one node to another
{
final String firstNodeId = UUID.randomUUID().toString();
final String secondNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore),
List.of(firstNodeAfter, secondNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
slotsChanged[8191] = true;
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
// New node added to cluster
{
final String firstNodeId = UUID.randomUUID().toString();
final String secondNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 12288));
final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class);
when(thirdNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(12288, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore),
List.of(firstNodeAfter, secondNodeAfter, thirdNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
for (int slot = 12288; slot < 16384; slot++) {
slotsChanged[slot] = true;
}
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
// Node removed from cluster
{
final String firstNodeId = UUID.randomUUID().toString();
final String secondNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(firstNodeId);
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 12288));
final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class);
when(thirdNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(12288, 16384));
final RedisClusterNode firstNodeAfter = mock(RedisClusterNode.class);
when(firstNodeAfter.getNodeId()).thenReturn(firstNodeId);
when(firstNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8192));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(8192, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore),
List.of(firstNodeAfter, secondNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
for (int slot = 12288; slot < 16384; slot++) {
slotsChanged[slot] = true;
}
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
// Node added, node removed, and slot moved
// Node removed from cluster
{
final String secondNodeId = UUID.randomUUID().toString();
final String thirdNodeId = UUID.randomUUID().toString();
final RedisClusterNode firstNodeBefore = mock(RedisClusterNode.class);
when(firstNodeBefore.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(firstNodeBefore.getSlots()).thenReturn(getSlotRange(0, 1));
final RedisClusterNode secondNodeBefore = mock(RedisClusterNode.class);
when(secondNodeBefore.getNodeId()).thenReturn(secondNodeId);
when(secondNodeBefore.getSlots()).thenReturn(getSlotRange(1, 8192));
final RedisClusterNode thirdNodeBefore = mock(RedisClusterNode.class);
when(thirdNodeBefore.getNodeId()).thenReturn(thirdNodeId);
when(thirdNodeBefore.getSlots()).thenReturn(getSlotRange(8192, 16384));
final RedisClusterNode secondNodeAfter = mock(RedisClusterNode.class);
when(secondNodeAfter.getNodeId()).thenReturn(secondNodeId);
when(secondNodeAfter.getSlots()).thenReturn(getSlotRange(0, 8191));
final RedisClusterNode thirdNodeAfter = mock(RedisClusterNode.class);
when(thirdNodeAfter.getNodeId()).thenReturn(thirdNodeId);
when(thirdNodeAfter.getSlots()).thenReturn(getSlotRange(8191, 16383));
final RedisClusterNode fourthNodeAfter = mock(RedisClusterNode.class);
when(fourthNodeAfter.getNodeId()).thenReturn(UUID.randomUUID().toString());
when(fourthNodeAfter.getSlots()).thenReturn(getSlotRange(16383, 16384));
final ClusterTopologyChangedEvent clusterTopologyChangedEvent = new ClusterTopologyChangedEvent(
List.of(firstNodeBefore, secondNodeBefore, thirdNodeBefore),
List.of(secondNodeAfter, thirdNodeAfter, fourthNodeAfter));
final boolean[] slotsChanged = new boolean[SlotHash.SLOT_COUNT];
slotsChanged[0] = true;
slotsChanged[8191] = true;
slotsChanged[16383] = true;
arguments.add(Arguments.of(clusterTopologyChangedEvent, slotsChanged));
}
return arguments;
}
private static List<Integer> getSlotRange(final int startInclusive, final int endExclusive) {
final List<Integer> slots = new ArrayList<>(endExclusive - startInclusive);
for (int i = startInclusive; i < endExclusive; i++) {
slots.add(i);
}
return slots;
}
}

View File

@@ -58,6 +58,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -124,8 +125,8 @@ class WebSocketConnectionTest {
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class));
mock(ClientPresenceManager.class), mock(PubSubClientEventManager.class), retrySchedulingExecutor,
messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))