Use MessageStream in WebSocketConnection

This commit is contained in:
Jon Chambers
2025-08-13 10:22:55 -04:00
committed by GitHub
parent 4c5dc118aa
commit 470e17963a
6 changed files with 745 additions and 1004 deletions

View File

@@ -0,0 +1,131 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.time.Instant;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.session.WebSocketSessionContext;
class AuthenticatedConnectListenerTest {
private AccountsManager accountsManager;
private DisconnectionRequestManager disconnectionRequestManager;
private WebSocketConnection authenticatedWebSocketConnection;
private AuthenticatedConnectListener authenticatedConnectListener;
private Account authenticatedAccount;
private WebSocketClient webSocketClient;
private WebSocketSessionContext webSocketSessionContext;
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
@BeforeEach
void setUpBeforeEach() {
accountsManager = mock(AccountsManager.class);
disconnectionRequestManager = mock(DisconnectionRequestManager.class);
authenticatedWebSocketConnection = mock(WebSocketConnection.class);
authenticatedConnectListener = new AuthenticatedConnectListener(accountsManager,
disconnectionRequestManager,
(_, _, _) -> authenticatedWebSocketConnection,
_ -> mock(OpenWebSocketCounter.class));
final Device device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
authenticatedAccount = mock(Account.class);
when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER);
when(authenticatedAccount.getDevice(DEVICE_ID)).thenReturn(Optional.of(device));
webSocketClient = mock(WebSocketClient.class);
webSocketSessionContext = mock(WebSocketSessionContext.class);
when(webSocketSessionContext.getClient()).thenReturn(webSocketClient);
}
@Test
void onWebSocketConnectAuthenticated() {
when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now()));
when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class))
.thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now()));
when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.of(authenticatedAccount));
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection);
verify(webSocketSessionContext).addWebsocketClosedListener(any());
verify(authenticatedWebSocketConnection).start();
}
@Test
void onWebSocketConnectAuthenticatedAccountNotFound() {
when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now()));
when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class))
.thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now()));
when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.empty());
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
verify(webSocketClient).close(eq(1011), anyString());
verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any());
verify(webSocketSessionContext, never()).addWebsocketClosedListener(any());
verify(authenticatedWebSocketConnection, never()).start();
}
@Test
void onWebSocketConnectAuthenticatedStartException() {
when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now()));
when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class))
.thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now()));
when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.of(authenticatedAccount));
doThrow(new RuntimeException()).when(authenticatedWebSocketConnection).start();
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection);
verify(webSocketSessionContext).addWebsocketClosedListener(any());
verify(authenticatedWebSocketConnection).start();
verify(webSocketClient).close(eq(1011), anyString());
}
@Test
void onWebSocketConnectUnauthenticated() {
authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext);
verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any());
verify(webSocketSessionContext, never()).addWebsocketClosedListener(any());
verify(authenticatedWebSocketConnection, never()).start();
}
}

View File

@@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -32,7 +33,6 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -42,7 +42,6 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
@@ -90,16 +89,18 @@ class WebSocketConnectionIntegrationTest {
private Scheduler messageDeliveryScheduler;
private ClientReleaseManager clientReleaseManager;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private long serialTimestamp = System.currentTimeMillis();
@BeforeEach
void setUp() throws Exception {
sharedExecutorService = Executors.newSingleThreadExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
@@ -115,10 +116,14 @@ class WebSocketConnectionIntegrationTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
redisMessageAvailabilityManager.start();
}
@AfterEach
void tearDown() throws Exception {
redisMessageAvailabilityManager.stop();
sharedExecutorService.shutdown();
//noinspection ResultOfMethodCallIgnored
sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS);
@@ -143,7 +148,8 @@ class WebSocketConnectionIntegrationTest {
messageDeliveryScheduler,
clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class));
mock(ExperimentEnrollmentManager.class)
);
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
@@ -171,36 +177,21 @@ class WebSocketConnectionIntegrationTest {
}
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
final AtomicBoolean queueCleared = new AtomicBoolean(false);
when(successResponse.getStatus()).thenReturn(200);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any()))
.thenReturn(CompletableFuture.completedFuture(successResponse));
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer(
(Answer<CompletableFuture<WebSocketResponseMessage>>) invocation -> {
synchronized (queueCleared) {
queueCleared.set(true);
queueCleared.notifyAll();
}
webSocketConnection.start();
return CompletableFuture.completedFuture(successResponse);
});
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor =
ArgumentCaptor.forClass(Optional.class);
webSocketConnection.processStoredMessages();
verify(webSocketClient, timeout(10_000))
.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
synchronized (queueCleared) {
while (!queueCleared.get()) {
queueCleared.wait();
}
}
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(
Optional.class);
verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"),
eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
verify(webSocketClient, times(persistedMessageCount + cachedMessageCount))
.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
final List<MessageProtos.Envelope> sentMessages = new ArrayList<>();
@@ -232,7 +223,8 @@ class WebSocketConnectionIntegrationTest {
messageDeliveryScheduler,
clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class));
mock(ExperimentEnrollmentManager.class)
);
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@@ -264,10 +256,10 @@ class WebSocketConnectionIntegrationTest {
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
CompletableFuture.failedFuture(new IOException("Connection closed")));
webSocketConnection.processStoredMessages();
webSocketConnection.start();
//noinspection unchecked
ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"),
eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
@@ -275,7 +267,7 @@ class WebSocketConnectionIntegrationTest {
eq(Optional.empty()));
final List<MessageProtos.Envelope> sentMessages = messageBodyCaptor.getAllValues().stream()
.map(Optional::get)
.map(Optional::orElseThrow)
.map(messageBytes -> {
try {
return Envelope.parseFrom(messageBytes);