Drop pub/sub operations from WebsocketConnection.

This commit is contained in:
Jon Chambers
2020-09-09 19:21:55 -04:00
committed by Jon Chambers
parent 4f2e06407b
commit 7e14a0bc30
8 changed files with 66 additions and 223 deletions

View File

@@ -59,6 +59,19 @@ public class ClientPresenceManagerTest extends AbstractRedisClusterTest {
assertTrue(clientPresenceManager.isPresent(accountUuid, deviceId));
}
@Test
public void testIsLocallyPresent() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
assertFalse(clientPresenceManager.isLocallyPresent(accountUuid, deviceId));
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
getRedisCluster().useCluster(connection -> connection.sync().flushall());
assertTrue(clientPresenceManager.isLocallyPresent(accountUuid, deviceId));
}
@Test
public void testLocalDisplacement() {
final UUID accountUuid = UUID.randomUUID();

View File

@@ -14,7 +14,6 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -78,7 +77,7 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
webSocketConnection = new WebSocketConnection(mock(PushSender.class),
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messages, messagesCache, mock(PushLatencyManager.class)),
account,

View File

@@ -12,15 +12,12 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebsocketSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
@@ -68,11 +65,9 @@ public class WebSocketConnectionTest {
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
private static final PushSender pushSender = mock(PushSender.class);
private static final ReceiptSender receiptSender = mock(ReceiptSender.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
@@ -80,7 +75,7 @@ public class WebSocketConnectionTest {
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(pushSender, receiptSender, storedMessages, pubSubManager, apnFallbackManager, mock(ClientPresenceManager.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, apnFallbackManager, mock(ClientPresenceManager.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
@@ -167,11 +162,10 @@ public class WebSocketConnectionTest {
}
});
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, "someid");
connection.onDispatchSubscribed(websocketAddress.serialize());
connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
assertTrue(futures.size() == 3);
@@ -186,111 +180,15 @@ public class WebSocketConnectionTest {
verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), eq(2L), eq(false));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString());
}
@Test
public void testOnlineSend() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
WebsocketSender websocketSender = mock(WebsocketSender.class);
when(pushSender.getWebSocketSender()).thenReturn(websocketSender);
Envelope firstMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("first".getBytes()))
.setSource("sender1")
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT)
.build();
Envelope secondMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("second".getBytes()))
.setSource("sender2")
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(2)
.setType(Envelope.Type.CIPHERTEXT)
.build();
List<OutgoingMessageEntity> pendingMessages = new LinkedList<>();
OutgoingMessageEntityList pendingMessagesList = new OutgoingMessageEntityList(pendingMessages, false);
when(device.getId()).thenReturn(2L);
when(device.getSignalingKey()).thenReturn(Base64.encodeBytes(new byte[52]));
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(UUID.randomUUID());
final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<Device>() {{
add(sender1device);
}};
Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices);
when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.get("sender2")).thenReturn(Optional.<Account>empty());
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(pendingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
account, device, client, "anotherid");
connection.onDispatchSubscribed(websocketAddress.serialize());
connection.onDispatchMessage(websocketAddress.serialize(), PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
.setContent(ByteString.copyFrom(firstMessage.toByteArray()))
.build().toByteArray());
connection.onDispatchMessage(websocketAddress.serialize(), PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
.setContent(ByteString.copyFrom(secondMessage.toByteArray()))
.build().toByteArray());
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(futures.size(), 2);
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()));
verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class));
verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
connection.stop();
verify(client).close(anyInt(), anyString());
}
@Test(timeout = 5_000L)
public void testOnlineSendViaKeyspaceNotification() throws Exception {
public void testOnlineSend() throws Exception {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
@@ -322,7 +220,7 @@ public class WebSocketConnectionTest {
// messages, the call to CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded
// future, and the whenComplete method will get called immediately on THIS thread, so we don't need to synchronize
// or wait for anything.
connection.onDispatchSubscribed("channel");
connection.start();
connection.handleNewMessagesAvailable();
@@ -349,11 +247,6 @@ public class WebSocketConnectionTest {
MessagesManager storedMessages = mock(MessagesManager.class);
WebsocketSender websocketSender = mock(WebsocketSender.class);
reset(websocketSender);
reset(pushSender);
when(pushSender.getWebSocketSender()).thenReturn(websocketSender);
final Envelope firstMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("first".getBytes()))
.setSource("sender1")
@@ -423,11 +316,10 @@ public class WebSocketConnectionTest {
}
});
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, "onemoreid");
connection.onDispatchSubscribed(websocketAddress.serialize());
connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
@@ -440,9 +332,8 @@ public class WebSocketConnectionTest {
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()));
verifyNoMoreInteractions(websocketSender);
verifyNoMoreInteractions(pushSender);
connection.onDispatchUnsubscribed(websocketAddress.serialize());
connection.stop();
verify(client).close(anyInt(), anyString());
}
@@ -450,7 +341,7 @@ public class WebSocketConnectionTest {
public void testProcessStoredMessageConcurrency() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, "concurrency");
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@@ -511,7 +402,7 @@ public class WebSocketConnectionTest {
public void testProcessStoredMessagesMultiplePages() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, "concurrency");
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@@ -554,7 +445,7 @@ public class WebSocketConnectionTest {
public void testProcessStoredMessagesSingleEmptyCall() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
@@ -583,7 +474,7 @@ public class WebSocketConnectionTest {
public void testRequeryOnStateMismatch() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
@@ -609,15 +500,10 @@ public class WebSocketConnectionTest {
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
final byte[] queryDbMessageBytes = PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
.build()
.toByteArray();
final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size());
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>)invocation -> {
connection.onDispatchMessage("channel", queryDbMessageBytes);
connection.handleNewMessagesAvailable();
sendLatch.countDown();
return CompletableFuture.completedFuture(successResponse);
@@ -635,7 +521,7 @@ public class WebSocketConnectionTest {
public void testProcessCachedMessagesOnly() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
@@ -667,7 +553,7 @@ public class WebSocketConnectionTest {
public void testProcessDatabaseMessagesAfterPersist() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();