Lifecycle management for Account objects reused accross websocket requests

This commit is contained in:
Ravi Khadiwala
2024-02-06 16:59:42 -06:00
committed by ravi-signal
parent 29ef3f0b41
commit 26ffa19f36
38 changed files with 1317 additions and 457 deletions

View File

@@ -28,8 +28,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@@ -52,7 +52,7 @@ class WebSocketAccountAuthenticatorTest {
accountAuthenticator = mock(AccountAuthenticator.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(mock(Account.class), mock(Device.class)))));
.thenReturn(Optional.of(new AuthenticatedAccount(mock(Account.class), mock(Device.class))));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty());
@@ -66,7 +66,7 @@ class WebSocketAccountAuthenticatorTest {
@Nullable final String authorizationHeaderValue,
final Map<String, List<String>> upgradeRequestParameters,
final boolean expectAccount,
final boolean expectCredentialsPresented) throws Exception {
final boolean expectInvalid) throws Exception {
when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters);
if (authorizationHeaderValue != null) {
@@ -74,13 +74,13 @@ class WebSocketAccountAuthenticatorTest {
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator);
accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAuthenticator.AuthenticationResult<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(
upgradeRequest);
final ReusableAuth<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(upgradeRequest);
assertEquals(expectAccount, result.getUser().isPresent());
assertEquals(expectCredentialsPresented, result.credentialsPresented());
assertEquals(expectAccount, result.ref().isPresent());
assertEquals(expectInvalid, result.invalidCredentialsProvided());
}
private static Stream<Arguments> testAuthenticate() {
@@ -94,17 +94,17 @@ class WebSocketAccountAuthenticatorTest {
HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD);
return Stream.of(
// if `Authorization` header is present, outcome should not depend on the value of query parameters
Arguments.of(headerWithValidAuth, Map.of(), true, true),
Arguments.of(headerWithValidAuth, Map.of(), true, false),
Arguments.of(headerWithInvalidAuth, Map.of(), false, true),
Arguments.of("invalid header value", Map.of(), false, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithValidAuth, false, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true),
// if `Authorization` header is not set, outcome should match the query params based auth
Arguments.of(null, paramsMapWithValidAuth, true, true),
Arguments.of(null, paramsMapWithValidAuth, true, false),
Arguments.of(null, paramsMapWithInvalidAuth, false, true),
Arguments.of(null, Map.of(), false, false)
);

View File

@@ -125,7 +125,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
scheduledExecutorService,
@@ -210,7 +210,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
scheduledExecutorService,
@@ -276,7 +276,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly

View File

@@ -64,9 +64,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux;
@@ -101,7 +101,7 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
auth = new AuthenticatedAccount(account, device);
upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class);
@@ -118,18 +118,19 @@ class WebSocketConnectionTest {
@Test
void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
mock(PushNotificationManager.class), mock(ClientPresenceManager.class),
retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.getUser().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
ReusableAuth<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.ref().orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@@ -144,8 +145,8 @@ class WebSocketConnectionTest {
// unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent());
assertFalse(account.credentialsPresented());
assertFalse(account.ref().isPresent());
assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(