Migrate MessagesDynamoDbRule to MessagesDynamoDbExtension

This commit is contained in:
Chris Eager
2021-09-14 14:54:22 -07:00
committed by Jon Chambers
parent 6a5d475198
commit 83e0a19561
6 changed files with 370 additions and 322 deletions

View File

@@ -5,9 +5,10 @@
package org.whispersystems.textsecuregcm.websocket;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.eq;
@@ -34,10 +35,10 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@@ -45,195 +46,208 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest {
class WebSocketConnectionIntegrationTest {
@Rule
public MessagesDynamoDbRule messagesDynamoDbRule = new MessagesDynamoDbRule();
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = MessagesDynamoDbExtension.build();
private ExecutorService executorService;
private MessagesDynamoDb messagesDynamoDb;
private MessagesCache messagesCache;
private ReportMessageManager reportMessageManager;
private Account account;
private Device device;
private WebSocketClient webSocketClient;
private WebSocketConnection webSocketConnection;
private ScheduledExecutorService retrySchedulingExecutor;
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private long serialTimestamp = System.currentTimeMillis();
private ExecutorService executorService;
private MessagesDynamoDb messagesDynamoDb;
private MessagesCache messagesCache;
private ReportMessageManager reportMessageManager;
private Account account;
private Device device;
private WebSocketClient webSocketClient;
private WebSocketConnection webSocketConnection;
private ScheduledExecutorService retrySchedulingExecutor;
@Before
@Override
public void setUp() throws Exception {
super.setUp();
private long serialTimestamp = System.currentTimeMillis();
executorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), executorService);
messagesDynamoDb = new MessagesDynamoDb(messagesDynamoDbRule.getDynamoDbClient(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7));
reportMessageManager = mock(ReportMessageManager.class);
account = mock(Account.class);
device = mock(Device.class);
webSocketClient = mock(WebSocketClient.class);
retrySchedulingExecutor = Executors.newSingleThreadScheduledExecutor();
@BeforeEach
void setUp() throws Exception {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
executorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
REDIS_CLUSTER_EXTENSION.getRedisCluster(), executorService);
messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME,
Duration.ofDays(7));
reportMessageManager = mock(ReportMessageManager.class);
account = mock(Account.class);
device = mock(Device.class);
webSocketClient = mock(WebSocketClient.class);
retrySchedulingExecutor = Executors.newSingleThreadScheduledExecutor();
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
retrySchedulingExecutor);
}
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
@After
@Override
public void tearDown() throws Exception {
executorService.shutdown();
executorService.awaitTermination(2, TimeUnit.SECONDS);
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
retrySchedulingExecutor);
}
retrySchedulingExecutor.shutdown();
retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS);
@AfterEach
void tearDown() throws Exception {
executorService.shutdown();
executorService.awaitTermination(2, TimeUnit.SECONDS);
super.tearDown();
}
retrySchedulingExecutor.shutdown();
retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS);
}
@Test(timeout = 15_000)
public void testProcessStoredMessages() throws InterruptedException {
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@Test
void testProcessStoredMessages() {
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
assertTimeout(Duration.ofSeconds(15), () -> {
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
persistedMessages.add(envelope);
expectedMessages.add(envelope);
}
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
persistedMessages.add(envelope);
expectedMessages.add(envelope);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
}
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
expectedMessages.add(envelope);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
final AtomicBoolean queueCleared = new AtomicBoolean(false);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
expectedMessages.add(envelope);
}
when(successResponse.getStatus()).thenReturn(200);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.completedFuture(successResponse));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
final AtomicBoolean queueCleared = new AtomicBoolean(false);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>)invocation -> {
when(successResponse.getStatus()).thenReturn(200);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
CompletableFuture.completedFuture(successResponse));
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer(
(Answer<CompletableFuture<WebSocketResponseMessage>>) invocation -> {
synchronized (queueCleared) {
queueCleared.set(true);
queueCleared.notifyAll();
queueCleared.set(true);
queueCleared.notifyAll();
}
return CompletableFuture.completedFuture(successResponse);
});
webSocketConnection.processStoredMessages();
synchronized (queueCleared) {
while (!queueCleared.get()) {
queueCleared.wait();
}
}
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(
Optional.class);
verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"),
eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
final List<MessageProtos.Envelope> sentMessages = new ArrayList<>();
for (final Optional<byte[]> maybeMessageBody : messageBodyCaptor.getAllValues()) {
maybeMessageBody.ifPresent(messageBytes -> {
try {
sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes));
} catch (final InvalidProtocolBufferException e) {
fail("Could not parse sent message");
}
});
}
webSocketConnection.processStoredMessages();
assertEquals(expectedMessages, sentMessages);
});
}
synchronized (queueCleared) {
while (!queueCleared.get()) {
queueCleared.wait();
}
@Test
void testProcessStoredMessagesClientClosed() {
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
assertTimeout(Duration.ofSeconds(15), () -> {
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope);
}
@SuppressWarnings("unchecked")
final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
}
verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
final List<MessageProtos.Envelope> sentMessages = new ArrayList<>();
expectedMessages.add(envelope);
}
for (final Optional<byte[]> maybeMessageBody : messageBodyCaptor.getAllValues()) {
maybeMessageBody.ifPresent(messageBytes -> {
try {
sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes));
} catch (final InvalidProtocolBufferException e) {
fail("Could not parse sent message");
}
});
}
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
CompletableFuture.failedFuture(new IOException("Connection closed")));
assertEquals(expectedMessages, sentMessages);
}
webSocketConnection.processStoredMessages();
@Test(timeout = 15_000)
public void testProcessStoredMessagesClientClosed() {
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
//noinspection unchecked
ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
expectedMessages.add(envelope);
}
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.failedFuture(new IOException("Connection closed")));
webSocketConnection.processStoredMessages();
//noinspection unchecked
ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"),
eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(),
eq(Optional.empty()));
final List<MessageProtos.Envelope> sentMessages = messageBodyCaptor.getAllValues().stream()
.map(Optional::get)
.map(messageBytes -> {
try {
return Envelope.parseFrom(messageBytes);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
try {
return Envelope.parseFrom(messageBytes);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
assertTrue(expectedMessages.containsAll(sentMessages));
});
}
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) {