Don't cache authenticated accounts in memory

This commit is contained in:
Jon Chambers
2025-06-23 08:40:05 -05:00
committed by GitHub
parent 9dfe51eac4
commit c952baa672
86 changed files with 961 additions and 2264 deletions

View File

@@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@@ -70,14 +69,12 @@ class WebSocketAccountAuthenticatorTest {
when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue);
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
if (expectInvalid) {
assertThrows(InvalidCredentialsException.class, () -> webSocketAuthenticator.authenticate(upgradeRequest));
} else {
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).ref().isPresent());
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).isPresent());
}
}

View File

@@ -43,11 +43,11 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@@ -111,7 +111,7 @@ class WebSocketConnectionIntegrationTest {
clientReleaseManager = mock(ClientReleaseManager.class);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
}
@@ -137,7 +137,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@@ -159,14 +160,14 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@@ -226,7 +227,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@@ -250,13 +252,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@@ -296,7 +298,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly
scheduledExecutorService,
@@ -321,13 +324,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}

View File

@@ -56,6 +56,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@@ -67,9 +68,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux;
@@ -90,7 +89,6 @@ class WebSocketConnectionTest {
private AccountsManager accountsManager;
private Account account;
private Device device;
private AuthenticatedDevice auth;
private UpgradeRequest upgradeRequest;
private MessagesManager messagesManager;
private ReceiptSender receiptSender;
@@ -104,7 +102,6 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedDevice(account, device);
upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class);
@@ -122,8 +119,8 @@ class WebSocketConnectionTest {
@Test
void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor,
messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class),
@@ -133,9 +130,9 @@ class WebSocketConnectionTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedDevice(account, device)));
ReusableAuth<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.ref().orElse(null));
Optional<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@@ -150,7 +147,7 @@ class WebSocketConnectionTest {
// unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.ref().isPresent());
assertFalse(account.isPresent());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(
@@ -174,7 +171,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@@ -191,7 +188,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(outgoingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -237,7 +234,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -320,7 +317,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@@ -337,7 +334,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(pendingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -364,7 +361,7 @@ class WebSocketConnectionTest {
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI))), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getClientTimestamp()));
connection.stop();
@@ -377,7 +374,7 @@ class WebSocketConnectionTest {
final WebSocketConnection connection = webSocketConnection(client);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -385,7 +382,7 @@ class WebSocketConnectionTest {
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(
messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenAnswer(invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
@@ -442,7 +439,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -490,7 +487,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -573,7 +570,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -581,7 +578,7 @@ class WebSocketConnectionTest {
final List<Envelope> messages = List.of(
createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first"));
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(messages))
.thenReturn(Flux.empty());
@@ -629,7 +626,7 @@ class WebSocketConnectionTest {
private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), account, device, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class), mock(ExperimentEnrollmentManager.class));
}
@@ -642,7 +639,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -669,7 +666,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -725,7 +722,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -741,11 +738,11 @@ class WebSocketConnectionTest {
// anything.
connection.processStoredMessages();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
connection.handleNewMessageAvailable();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, true);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, true);
}
@Test
@@ -756,7 +753,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@@ -773,7 +770,7 @@ class WebSocketConnectionTest {
connection.processStoredMessages();
connection.handleMessagesPersisted();
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
}
@Test
@@ -783,9 +780,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
@@ -812,9 +809,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
final WebSocketClient client = mock(WebSocketClient.class);
@@ -835,7 +832,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final int totalMessages = 1000;
@@ -884,7 +881,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final AtomicBoolean canceled = new AtomicBoolean();