Revert "Revert keyspace delivery for all messages"

This reverts commit 4dc49604b6.
This commit is contained in:
Jon Chambers
2020-09-14 15:57:06 -04:00
committed by Jon Chambers
parent 8016e84bc7
commit 62c31eb202
7 changed files with 427 additions and 47 deletions

View File

@@ -257,7 +257,7 @@ public class MessageControllerTest {
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
OutgoingMessageEntityList response =
resources.getJerseyTest().target("/v1/messages/")
@@ -294,7 +294,7 @@ public class MessageControllerTest {
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
Response response =
resources.getJerseyTest().target("/v1/messages/")

View File

@@ -1,6 +1,7 @@
package org.whispersystems.textsecuregcm.tests.websocket;
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
@@ -21,16 +22,13 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
@@ -39,11 +37,25 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import io.dropwizard.auth.basic.BasicCredentials;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
public class WebSocketConnectionTest {
@@ -138,7 +150,7 @@ public class WebSocketConnectionTest {
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent))
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -225,7 +237,7 @@ public class WebSocketConnectionTest {
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent))
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(pendingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -274,6 +286,64 @@ public class WebSocketConnectionTest {
verify(client).close(anyInt(), anyString());
}
@Test(timeout = 5_000L)
public void testOnlineSendViaKeyspaceNotification() throws Exception {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
final AtomicInteger sendCounter = new AtomicInteger(0);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>)invocation -> {
synchronized (sendCounter) {
sendCounter.incrementAndGet();
sendCounter.notifyAll();
}
return CompletableFuture.completedFuture(successResponse);
});
// This is a little hacky and non-obvious, but because the first call to getMessagesForDevice returns empty list of
// messages, the call to CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded
// future, and the whenComplete method will get called immediately on THIS thread, so we don't need to synchronize
// or wait for anything.
connection.onDispatchSubscribed("channel");
connection.handleNewMessagesAvailable();
synchronized (sendCounter) {
while (sendCounter.get() < 1) {
sendCounter.wait();
}
}
connection.handleNewMessagesAvailable();
synchronized (sendCounter) {
while (sendCounter.get() < 2) {
sendCounter.wait();
}
}
verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class));
}
@Test
public void testPendingSend() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
@@ -336,7 +406,7 @@ public class WebSocketConnectionTest {
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent))
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(pendingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -376,6 +446,251 @@ public class WebSocketConnectionTest {
verify(client).close(anyInt(), anyString());
}
@Test(timeout = 5000L)
public void testProcessStoredMessageConcurrency() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
threadWaiting.notifyAll();
}
synchronized (returnMessageList) {
while (!returnMessageList.get()) {
returnMessageList.wait();
}
}
return new OutgoingMessageEntityList(Collections.emptyList(), false);
});
final Thread[] threads = new Thread[10];
final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
connection.processStoredMessages();
unblockedThreadsLatch.countDown();
});
threads[i].start();
}
unblockedThreadsLatch.await();
synchronized (threadWaiting) {
while (!threadWaiting.get()) {
threadWaiting.wait();
}
}
synchronized (returnMessageList) {
returnMessageList.set(true);
returnMessageList.notifyAll();
}
for (final Thread thread : threads) {
thread.join();
}
verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString(), eq(false));
}
@Test(timeout = 5000L)
public void testProcessStoredMessagesMultiplePages() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false))
.thenReturn(firstPage)
.thenReturn(secondPage);
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size());
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>)invocation -> {
sendLatch.countDown();
return CompletableFuture.completedFuture(successResponse);
});
connection.processStoredMessages();
sendLatch.await();
verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class));
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
@Test
public void testProcessStoredMessagesSingleEmptyCall() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
// This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to
// CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the
// whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for
// anything.
connection.processStoredMessages();
connection.processStoredMessages();
verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
@Test(timeout = 5000L)
public void testRequeryOnStateMismatch() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(firstPage)
.thenReturn(secondPage)
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
final byte[] queryDbMessageBytes = PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
.build()
.toByteArray();
final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size());
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>)invocation -> {
connection.onDispatchMessage("channel", queryDbMessageBytes);
sendLatch.countDown();
return CompletableFuture.completedFuture(successResponse);
});
connection.processStoredMessages();
sendLatch.await();
verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class));
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
@Test
public void testProcessCachedMessagesOnly() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
// This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to
// CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the
// whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for
// anything.
connection.processStoredMessages();
verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false);
connection.handleNewMessagesAvailable();
verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), true);
}
@Test
public void testProcessDatabaseMessagesAfterPersist() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency");
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
// This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to
// CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the
// whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for
// anything.
connection.processStoredMessages();
connection.handleMessagesPersisted();
verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false);
}
private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,