Clarify roles/responsibilities of components in the message-handling pathway

This commit is contained in:
Jon Chambers
2025-01-31 10:24:50 -05:00
committed by GitHub
parent 282bcf6f34
commit 48ada8e8ca
33 changed files with 1338 additions and 1199 deletions

View File

@@ -10,23 +10,25 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.commons.lang3.RandomStringUtils;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -49,17 +51,21 @@ class MessageSenderTest {
@CartesianTest
void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean onlineMessage,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
final boolean expectPushNotificationAttempt = !clientPresent && !onlineMessage;
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final MessageProtos.Envelope message = generateRandomMessage();
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder()
.setEphemeral(ephemeral)
.setUrgent(urgent)
.build();
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
@@ -72,18 +78,61 @@ class MessageSenderTest {
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
}
when(messagesManager.insert(eq(accountIdentifier), eq(deviceId), any())).thenReturn(clientPresent);
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
assertDoesNotThrow(() -> messageSender.sendMessage(account, device, message, onlineMessage));
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
final MessageProtos.Envelope expectedMessage = onlineMessage
final MessageProtos.Envelope expectedMessage = ephemeral
? message.toBuilder().setEphemeral(true).build()
: message.toBuilder().build();
verify(messagesManager).insert(accountIdentifier, deviceId, expectedMessage);
verify(messagesManager).insert(accountIdentifier, Map.of(deviceId, expectedMessage));
if (expectPushNotificationAttempt) {
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, expectedMessage.getUrgent());
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
} else {
verifyNoInteractions(pushNotificationManager);
}
}
@CartesianTest
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(device.getId()).thenReturn(deviceId);
if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token");
} else {
doThrow(NotPushRegisteredException.class)
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
}
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
Collections.emptyMap(),
System.currentTimeMillis(),
false,
ephemeral,
urgent)
.join());
if (expectPushNotificationAttempt) {
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
} else {
verifyNoInteractions(pushNotificationManager);
}
@@ -123,14 +172,4 @@ class MessageSenderTest {
return arguments;
}
private MessageProtos.Envelope generateRandomMessage() {
return MessageProtos.Envelope.newBuilder()
.setClientTimestamp(System.currentTimeMillis())
.setServerTimestamp(System.currentTimeMillis())
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(UUID.randomUUID().toString())
.build();
}
}

View File

@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -104,7 +105,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
verify(messageSender, never()).sendMessages(eq(account), any());
}
@Test
@@ -118,7 +119,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
verify(messageSender, never()).sendMessages(eq(account), any());
}
@Test
@@ -155,10 +156,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@@ -203,10 +209,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@@ -249,10 +260,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@@ -291,10 +307,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@@ -335,10 +356,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));

View File

@@ -84,7 +84,7 @@ class MessagePersisterIntegrationTest {
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService);
messageDeletionExecutorService, Clock.systemUTC());
websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
asyncOperationQueueingExecutor = Executors.newSingleThreadExecutor();
@@ -143,7 +143,7 @@ class MessagePersisterIntegrationTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message);
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message).join();
expectedMessages.add(message);
}

View File

@@ -358,7 +358,7 @@ class MessagePersisterTest {
.setServerGuid(messageGuid.toString())
.build();
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope);
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope).join();
}
}

View File

@@ -40,7 +40,7 @@ class MessagesCacheGetItemsScriptTest {
.setServerGuid(serverGuid)
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());

View File

@@ -41,7 +41,7 @@ class MessagesCacheInsertScriptTest {
.setServerGuid(UUID.randomUUID().toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId));
@@ -50,11 +50,11 @@ class MessagesCacheInsertScriptTest {
.setServerGuid(UUID.randomUUID().toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope2);
insertScript.executeAsync(destinationUuid, deviceId, envelope2);
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId));
insertScript.execute(destinationUuid, deviceId, envelope1);
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId),
"Messages with same GUID should be deduplicated");
@@ -89,10 +89,10 @@ class MessagesCacheInsertScriptTest {
final MessagesCacheInsertScript insertScript =
new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
assertFalse(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
assertFalse(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build()));
.build()).join());
final FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubClusterConnection =
REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection();
@@ -100,9 +100,9 @@ class MessagesCacheInsertScriptTest {
pubSubClusterConnection.usePubSubConnection(connection ->
connection.sync().ssubscribe(WebSocketConnectionEventManager.getClientEventChannel(destinationUuid, deviceId)));
assertTrue(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
assertTrue(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build()));
.build()).join());
}
}

View File

@@ -6,7 +6,9 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.lettuce.core.RedisCommandExecutionException;
import java.util.ArrayList;
@@ -14,8 +16,10 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import io.lettuce.core.RedisException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
@@ -39,8 +43,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
@@ -82,15 +86,17 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID));
insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)).join();
final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class,
() -> insertMrmScript.execute(sharedMrmKey,
final CompletionException completionException = assertThrows(CompletionException.class,
() -> insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()),
Device.PRIMARY_ID)));
Device.PRIMARY_ID)).join());
assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage());
assertInstanceOf(RedisException.class, completionException.getCause());
assertTrue(completionException.getCause().getMessage()
.contains(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS));
}
}

View File

@@ -34,7 +34,7 @@ class MessagesCacheRemoveByGuidScriptTest {
.setServerGuid(serverGuid.toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());

View File

@@ -35,7 +35,7 @@ class MessagesCacheRemoveQueueScriptTest {
.setServerGuid(UUID.randomUUID().toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());

View File

@@ -41,8 +41,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
insertMrmScript.executeAsync(sharedMrmKey, MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
@@ -103,8 +102,8 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId));
insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId)).join();
sharedMrmKeys.add(sharedMrmKey);
}

View File

@@ -122,7 +122,7 @@ class MessagesCacheTest {
void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, sealedSender)));
generateRandomMessage(messageGuid, sealedSender))).join();
}
@Test
@@ -130,8 +130,8 @@ class MessagesCacheTest {
final UUID duplicateGuid = UUID.randomUUID();
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10)
.count()
@@ -149,7 +149,7 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
@@ -175,12 +175,12 @@ class MessagesCacheTest {
for (final MessageProtos.Envelope message : messagesToRemove) {
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
message);
message).join();
}
for (final MessageProtos.Envelope message : messagesToPreserve) {
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
message);
message).join();
}
final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
@@ -197,7 +197,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
}
@@ -208,7 +208,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join());
}
@@ -223,7 +223,7 @@ class MessagesCacheTest {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
assertEquals(expectedOldestTimestamp,
messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
expectedMessages.add(message);
@@ -248,7 +248,7 @@ class MessagesCacheTest {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
expectedMessages.add(message);
}
@@ -262,7 +262,7 @@ class MessagesCacheTest {
final UUID message1Guid = UUID.randomUUID();
final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender);
messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1);
messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1).join();
final List<MessageProtos.Envelope> get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID,
1);
assertEquals(List.of(message1), get1);
@@ -272,7 +272,7 @@ class MessagesCacheTest {
final UUID message2Guid = UUID.randomUUID();
final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender);
messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2);
messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2).join();
assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1));
}
@@ -287,7 +287,7 @@ class MessagesCacheTest {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
expectedMessages.add(message);
}
@@ -295,7 +295,7 @@ class MessagesCacheTest {
final UUID ephemeralMessageGuid = UUID.randomUUID();
final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true)
.toBuilder().setEphemeral(true).build();
messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage);
messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage).join();
final Clock cacheClock;
if (expectStale) {
@@ -352,7 +352,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
}
}
@@ -372,7 +372,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
}
}
@@ -404,7 +404,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID();
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, sealedSender));
generateRandomMessage(messageGuid, sealedSender)).join();
final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID);
assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
@@ -427,7 +427,7 @@ class MessagesCacheTest {
final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) {
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
} else {
sharedMrmDataKey = "{1}".getBytes(StandardCharsets.UTF_8);
}
@@ -440,7 +440,7 @@ class MessagesCacheTest {
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.clearContent()
.build();
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message);
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message).join();
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
@@ -487,13 +487,13 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid,
new AciServiceIdentifier(destinationUuid), true);
messagesCache.insert(messageGuid, destinationUuid, deviceId, message);
messagesCache.insert(messageGuid, destinationUuid, deviceId, message).join();
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);
final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) {
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
} else {
sharedMrmDataKey = new byte[]{1};
}
@@ -505,7 +505,7 @@ class MessagesCacheTest {
.clearContent()
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage);
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage).join();
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);

View File

@@ -7,22 +7,42 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.InvalidVersionException;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
import org.whispersystems.textsecuregcm.util.TestClock;
import reactor.core.publisher.Mono;
class MessagesManagerTest {
@@ -31,8 +51,15 @@ class MessagesManagerTest {
private final MessagesCache messagesCache = mock(MessagesCache.class);
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
reportMessageManager, Executors.newSingleThreadExecutor());
reportMessageManager, Executors.newSingleThreadExecutor(), CLOCK);
@BeforeEach
void setUp() {
when(messagesCache.insert(any(), any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(true));
}
@Test
void insert() {
@@ -43,7 +70,7 @@ class MessagesManagerTest {
final UUID destinationUuid = UUID.randomUUID();
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, message);
messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, message));
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
@@ -51,11 +78,113 @@ class MessagesManagerTest {
.setSourceServiceId(destinationUuid.toString())
.build();
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage);
messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, syncMessage));
verifyNoMoreInteractions(reportMessageManager);
}
@Test
void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException {
final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier unresolvedAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Account singleDeviceAccount = mock(Account.class);
final Account multiDeviceAccount = mock(Account.class);
when(singleDeviceAccount.getIdentifier(IdentityType.ACI))
.thenReturn(singleDeviceAccountAciServiceIdentifier.uuid());
when(multiDeviceAccount.getIdentifier(IdentityType.ACI))
.thenReturn(multiDeviceAccountAciServiceIdentifier.uuid());
final byte[] multiRecipientMessageBytes = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(
new TestRecipient(singleDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 1, new byte[48]),
new TestRecipient(multiDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 2, new byte[48]),
new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]),
new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]),
new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48])
));
final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes);
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = new HashMap<>();
multiRecipientMessage.getRecipients().forEach(((serviceId, recipient) -> {
if (serviceId.getRawUUID().equals(singleDeviceAccountAciServiceIdentifier.uuid()) ||
serviceId.getRawUUID().equals(singleDeviceAccountPniServiceIdentifier.uuid())) {
resolvedRecipients.put(recipient, singleDeviceAccount);
} else if (serviceId.getRawUUID().equals(multiDeviceAccountAciServiceIdentifier.uuid())) {
resolvedRecipients.put(recipient, multiDeviceAccount);
}
}));
final Map<Account, Map<Byte, Boolean>> expectedPresenceByAccountAndDeviceId = Map.of(
singleDeviceAccount, Map.of(Device.PRIMARY_ID, true),
multiDeviceAccount, Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
);
final Map<UUID, Map<Byte, Boolean>> presenceByAccountIdentifierAndDeviceId = Map.of(
singleDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, true),
multiDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
);
final byte[] sharedMrmKey = "shared-mrm-key".getBytes(StandardCharsets.UTF_8);
when(messagesCache.insertSharedMultiRecipientMessagePayload(multiRecipientMessage))
.thenReturn(CompletableFuture.completedFuture(sharedMrmKey));
when(messagesCache.insert(any(), any(), anyByte(), any()))
.thenAnswer(invocation -> {
final UUID accountIdentifier = invocation.getArgument(1);
final byte deviceId = invocation.getArgument(2);
return CompletableFuture.completedFuture(
presenceByAccountIdentifierAndDeviceId.getOrDefault(accountIdentifier, Collections.emptyMap())
.getOrDefault(deviceId, false));
});
final long clientTimestamp = System.currentTimeMillis();
final boolean isStory = ThreadLocalRandom.current().nextBoolean();
final boolean isEphemeral = ThreadLocalRandom.current().nextBoolean();
final boolean isUrgent = ThreadLocalRandom.current().nextBoolean();
final Envelope prototypeExpectedMessage = Envelope.newBuilder()
.setType(Envelope.Type.UNIDENTIFIED_SENDER)
.setClientTimestamp(clientTimestamp)
.setServerTimestamp(CLOCK.millis())
.setStory(isStory)
.setEphemeral(isEphemeral)
.setUrgent(isUrgent)
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
.build();
assertEquals(expectedPresenceByAccountAndDeviceId,
messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join());
verify(messagesCache).insert(any(),
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
eq(Device.PRIMARY_ID),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
verify(messagesCache).insert(any(),
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
eq(Device.PRIMARY_ID),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build()));
verify(messagesCache).insert(any(),
eq(multiDeviceAccountAciServiceIdentifier.uuid()),
eq((byte) (Device.PRIMARY_ID + 1)),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
verify(messagesCache, never()).insert(any(),
eq(unresolvedAccountAciServiceIdentifier.uuid()),
anyByte(),
any());
}
@ParameterizedTest
@CsvSource({
"false, false, false",

View File

@@ -29,6 +29,7 @@ class ReportMessageDynamoDbTest {
void setUp() {
this.reportMessageDynamoDb = new ReportMessageDynamoDb(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.REPORT_MESSAGES.tableName(),
Duration.ofDays(1));
}
@@ -44,8 +45,8 @@ class ReportMessageDynamoDbTest {
() -> assertFalse(reportMessageDynamoDb.remove(hash2))
);
reportMessageDynamoDb.store(hash1);
reportMessageDynamoDb.store(hash2);
reportMessageDynamoDb.store(hash1).join();
reportMessageDynamoDb.store(hash2).join();
assertAll("both hashes should be found",
() -> assertTrue(reportMessageDynamoDb.remove(hash1)),

View File

@@ -18,6 +18,7 @@ import static org.mockito.Mockito.when;
import java.time.Duration;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
@@ -68,8 +69,8 @@ class ReportMessageManagerTest {
verify(reportMessageDynamoDb).store(any());
doThrow(RuntimeException.class)
.when(reportMessageDynamoDb).store(any());
when(reportMessageDynamoDb.store(any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
assertDoesNotThrow(() -> reportMessageManager.store(sourceAci.toString(), messageGuid));
}

View File

@@ -0,0 +1,92 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import java.nio.ByteBuffer;
import java.util.List;
public class MultiRecipientMessageHelper {
private MultiRecipientMessageHelper() {
}
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients) {
return generateMultiRecipientMessage(recipients, 32);
}
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients, final int sharedPayloadSize) {
if (sharedPayloadSize < 32) {
throw new IllegalArgumentException("Shared payload size must be at least 32 bytes");
}
final ByteBuffer buffer = ByteBuffer.allocate(payloadSize(recipients, sharedPayloadSize));
// first write the header
buffer.put((byte) 0x23); // version byte
// count varint
writeVarint(buffer, recipients.size());
recipients.forEach(recipient -> {
buffer.put(recipient.uuid().toFixedWidthByteArray());
assert recipient.deviceIds().length == recipient.registrationIds().length;
for (int i = 0; i < recipient.deviceIds().length; i++) {
final int hasMore = i == recipient.deviceIds().length - 1 ? 0 : 0x8000;
buffer.put(recipient.deviceIds()[i]); // device id (1 byte)
buffer.putShort((short) (recipient.registrationIds()[i] | hasMore)); // registration id (2 bytes)
}
buffer.put(recipient.perRecipientKeyMaterial()); // key material (48 bytes)
});
// now write the actual message body (empty for now)
writeVarint(buffer, sharedPayloadSize);
buffer.put(new byte[sharedPayloadSize]);
return buffer.array();
}
private static void writeVarint(final ByteBuffer buffer, long n) {
if (n < 0) {
throw new IllegalArgumentException();
}
while (n >= 0x80) {
buffer.put ((byte) (n & 0x7F | 0x80));
n >>= 7;
}
buffer.put((byte) (n & 0x7F));
}
private static int payloadSize(final List<TestRecipient> recipients, final int sharedPayloadSize) {
final int fixedBytesPerRecipient = 17 // Service identifier length
+ 48; // Per-recipient key material
final int bytesForDevices = 3 * recipients.stream()
.mapToInt(recipient -> recipient.deviceIds().length)
.sum();
return 1 // Version byte
+ varintLength(recipients.size())
+ (recipients.size() * fixedBytesPerRecipient)
+ bytesForDevices
+ varintLength(sharedPayloadSize)
+ sharedPayloadSize;
}
private static int varintLength(long n) {
int length = 0;
while (n >= 0x80) {
length += 1;
n >>= 7;
}
return length + 1;
}
}

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
public record TestRecipient(ServiceIdentifier uuid,
byte[] deviceIds,
int[] registrationIds,
byte[] perRecipientKeyMaterial) {
public TestRecipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
this(uuid, new byte[]{deviceId}, new int[]{registrationId}, perRecipientKeyMaterial);
}
}

View File

@@ -132,7 +132,7 @@ class WebSocketConnectionIntegrationTest {
void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount) {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
@@ -164,7 +164,7 @@ class WebSocketConnectionIntegrationTest {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@@ -220,7 +220,7 @@ class WebSocketConnectionIntegrationTest {
void testProcessStoredMessagesClientClosed() {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
@@ -253,7 +253,7 @@ class WebSocketConnectionIntegrationTest {
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@@ -289,7 +289,7 @@ class WebSocketConnectionIntegrationTest {
void testProcessStoredMessagesSendFutureTimeout() {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
@@ -323,7 +323,7 @@ class WebSocketConnectionIntegrationTest {
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
expectedMessages.add(envelope);
}