mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-20 20:28:06 +01:00
Clarify roles/responsibilities of components in the message-handling pathway
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -10,23 +10,25 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.ArgumentMatchers.anyByte;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
@@ -49,17 +51,21 @@ class MessageSenderTest {
|
||||
|
||||
@CartesianTest
|
||||
void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean onlineMessage,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
||||
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !onlineMessage;
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
final MessageProtos.Envelope message = generateRandomMessage();
|
||||
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder()
|
||||
.setEphemeral(ephemeral)
|
||||
.setUrgent(urgent)
|
||||
.build();
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
@@ -72,18 +78,61 @@ class MessageSenderTest {
|
||||
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
|
||||
}
|
||||
|
||||
when(messagesManager.insert(eq(accountIdentifier), eq(deviceId), any())).thenReturn(clientPresent);
|
||||
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMessage(account, device, message, onlineMessage));
|
||||
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
|
||||
|
||||
final MessageProtos.Envelope expectedMessage = onlineMessage
|
||||
final MessageProtos.Envelope expectedMessage = ephemeral
|
||||
? message.toBuilder().setEphemeral(true).build()
|
||||
: message.toBuilder().build();
|
||||
|
||||
verify(messagesManager).insert(accountIdentifier, deviceId, expectedMessage);
|
||||
verify(messagesManager).insert(accountIdentifier, Map.of(deviceId, expectedMessage));
|
||||
|
||||
if (expectPushNotificationAttempt) {
|
||||
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, expectedMessage.getUrgent());
|
||||
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
|
||||
} else {
|
||||
verifyNoInteractions(pushNotificationManager);
|
||||
}
|
||||
}
|
||||
|
||||
@CartesianTest
|
||||
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
||||
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
} else {
|
||||
doThrow(NotPushRegisteredException.class)
|
||||
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
|
||||
}
|
||||
|
||||
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
|
||||
Collections.emptyMap(),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
ephemeral,
|
||||
urgent)
|
||||
.join());
|
||||
|
||||
if (expectPushNotificationAttempt) {
|
||||
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
|
||||
} else {
|
||||
verifyNoInteractions(pushNotificationManager);
|
||||
}
|
||||
@@ -123,14 +172,4 @@ class MessageSenderTest {
|
||||
|
||||
return arguments;
|
||||
}
|
||||
|
||||
private MessageProtos.Envelope generateRandomMessage() {
|
||||
return MessageProtos.Envelope.newBuilder()
|
||||
.setClientTimestamp(System.currentTimeMillis())
|
||||
.setServerTimestamp(System.currentTimeMillis())
|
||||
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256)))
|
||||
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
|
||||
.setServerGuid(UUID.randomUUID().toString())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -104,7 +105,7 @@ public class ChangeNumberManagerTest {
|
||||
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
|
||||
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
|
||||
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
|
||||
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
|
||||
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -118,7 +119,7 @@ public class ChangeNumberManagerTest {
|
||||
|
||||
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
|
||||
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
|
||||
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
|
||||
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -155,10 +156,15 @@ public class ChangeNumberManagerTest {
|
||||
|
||||
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
|
||||
|
||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||
|
||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||
@@ -203,10 +209,15 @@ public class ChangeNumberManagerTest {
|
||||
|
||||
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
||||
|
||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||
|
||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||
@@ -249,10 +260,15 @@ public class ChangeNumberManagerTest {
|
||||
|
||||
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
||||
|
||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||
|
||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||
@@ -291,10 +307,15 @@ public class ChangeNumberManagerTest {
|
||||
|
||||
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds);
|
||||
|
||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||
|
||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||
@@ -335,10 +356,15 @@ public class ChangeNumberManagerTest {
|
||||
|
||||
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
||||
|
||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||
|
||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||
|
||||
@@ -84,7 +84,7 @@ class MessagePersisterIntegrationTest {
|
||||
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
|
||||
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
|
||||
messageDeletionExecutorService);
|
||||
messageDeletionExecutorService, Clock.systemUTC());
|
||||
|
||||
websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||
asyncOperationQueueingExecutor = Executors.newSingleThreadExecutor();
|
||||
@@ -143,7 +143,7 @@ class MessagePersisterIntegrationTest {
|
||||
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
|
||||
|
||||
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message);
|
||||
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message).join();
|
||||
expectedMessages.add(message);
|
||||
}
|
||||
|
||||
|
||||
@@ -358,7 +358,7 @@ class MessagePersisterTest {
|
||||
.setServerGuid(messageGuid.toString())
|
||||
.build();
|
||||
|
||||
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope);
|
||||
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope).join();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class MessagesCacheGetItemsScriptTest {
|
||||
.setServerGuid(serverGuid)
|
||||
.build();
|
||||
|
||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
||||
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||
|
||||
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
@@ -41,7 +41,7 @@ class MessagesCacheInsertScriptTest {
|
||||
.setServerGuid(UUID.randomUUID().toString())
|
||||
.build();
|
||||
|
||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
||||
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||
|
||||
assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId));
|
||||
|
||||
@@ -50,11 +50,11 @@ class MessagesCacheInsertScriptTest {
|
||||
.setServerGuid(UUID.randomUUID().toString())
|
||||
.build();
|
||||
|
||||
insertScript.execute(destinationUuid, deviceId, envelope2);
|
||||
insertScript.executeAsync(destinationUuid, deviceId, envelope2);
|
||||
|
||||
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId));
|
||||
|
||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
||||
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||
|
||||
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId),
|
||||
"Messages with same GUID should be deduplicated");
|
||||
@@ -89,10 +89,10 @@ class MessagesCacheInsertScriptTest {
|
||||
final MessagesCacheInsertScript insertScript =
|
||||
new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
assertFalse(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
||||
assertFalse(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
||||
.setServerTimestamp(Instant.now().getEpochSecond())
|
||||
.setServerGuid(UUID.randomUUID().toString())
|
||||
.build()));
|
||||
.build()).join());
|
||||
|
||||
final FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubClusterConnection =
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection();
|
||||
@@ -100,9 +100,9 @@ class MessagesCacheInsertScriptTest {
|
||||
pubSubClusterConnection.usePubSubConnection(connection ->
|
||||
connection.sync().ssubscribe(WebSocketConnectionEventManager.getClientEventChannel(destinationUuid, deviceId)));
|
||||
|
||||
assertTrue(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
||||
assertTrue(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
||||
.setServerTimestamp(Instant.now().getEpochSecond())
|
||||
.setServerGuid(UUID.randomUUID().toString())
|
||||
.build()));
|
||||
.build()).join());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.lettuce.core.RedisCommandExecutionException;
|
||||
import java.util.ArrayList;
|
||||
@@ -14,8 +16,10 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletionException;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
import io.lettuce.core.RedisException;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
@@ -39,8 +43,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||
insertMrmScript.execute(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(destinations));
|
||||
insertMrmScript.executeAsync(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
|
||||
|
||||
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
|
||||
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
|
||||
@@ -82,15 +86,17 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||
insertMrmScript.execute(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID));
|
||||
insertMrmScript.executeAsync(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)).join();
|
||||
|
||||
final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class,
|
||||
() -> insertMrmScript.execute(sharedMrmKey,
|
||||
final CompletionException completionException = assertThrows(CompletionException.class,
|
||||
() -> insertMrmScript.executeAsync(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()),
|
||||
Device.PRIMARY_ID)));
|
||||
Device.PRIMARY_ID)).join());
|
||||
|
||||
assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage());
|
||||
assertInstanceOf(RedisException.class, completionException.getCause());
|
||||
assertTrue(completionException.getCause().getMessage()
|
||||
.contains(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ class MessagesCacheRemoveByGuidScriptTest {
|
||||
.setServerGuid(serverGuid.toString())
|
||||
.build();
|
||||
|
||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
||||
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||
|
||||
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
@@ -35,7 +35,7 @@ class MessagesCacheRemoveQueueScriptTest {
|
||||
.setServerGuid(UUID.randomUUID().toString())
|
||||
.build();
|
||||
|
||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
||||
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||
|
||||
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
@@ -41,8 +41,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||
insertMrmScript.execute(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(destinations));
|
||||
insertMrmScript.executeAsync(sharedMrmKey, MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
|
||||
|
||||
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
@@ -103,8 +102,8 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
|
||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||
|
||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||
insertMrmScript.execute(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId));
|
||||
insertMrmScript.executeAsync(sharedMrmKey,
|
||||
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId)).join();
|
||||
|
||||
sharedMrmKeys.add(sharedMrmKey);
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ class MessagesCacheTest {
|
||||
void testInsert(final boolean sealedSender) {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
generateRandomMessage(messageGuid, sealedSender)));
|
||||
generateRandomMessage(messageGuid, sealedSender))).join();
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -130,8 +130,8 @@ class MessagesCacheTest {
|
||||
final UUID duplicateGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
|
||||
|
||||
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
|
||||
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
|
||||
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
|
||||
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
|
||||
|
||||
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10)
|
||||
.count()
|
||||
@@ -149,7 +149,7 @@ class MessagesCacheTest {
|
||||
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||
final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
|
||||
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
|
||||
|
||||
@@ -175,12 +175,12 @@ class MessagesCacheTest {
|
||||
|
||||
for (final MessageProtos.Envelope message : messagesToRemove) {
|
||||
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
message);
|
||||
message).join();
|
||||
}
|
||||
|
||||
for (final MessageProtos.Envelope message : messagesToPreserve) {
|
||||
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
message);
|
||||
message).join();
|
||||
}
|
||||
|
||||
final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
@@ -197,7 +197,7 @@ class MessagesCacheTest {
|
||||
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||
|
||||
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
|
||||
}
|
||||
@@ -208,7 +208,7 @@ class MessagesCacheTest {
|
||||
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||
|
||||
assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join());
|
||||
}
|
||||
@@ -223,7 +223,7 @@ class MessagesCacheTest {
|
||||
for (int i = 0; i < messageCount; i++) {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||
assertEquals(expectedOldestTimestamp,
|
||||
messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
|
||||
expectedMessages.add(message);
|
||||
@@ -248,7 +248,7 @@ class MessagesCacheTest {
|
||||
for (int i = 0; i < messageCount; i++) {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||
expectedMessages.add(message);
|
||||
}
|
||||
|
||||
@@ -262,7 +262,7 @@ class MessagesCacheTest {
|
||||
|
||||
final UUID message1Guid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender);
|
||||
messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1);
|
||||
messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1).join();
|
||||
final List<MessageProtos.Envelope> get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
1);
|
||||
assertEquals(List.of(message1), get1);
|
||||
@@ -272,7 +272,7 @@ class MessagesCacheTest {
|
||||
final UUID message2Guid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender);
|
||||
|
||||
messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2);
|
||||
messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2).join();
|
||||
|
||||
assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1));
|
||||
}
|
||||
@@ -287,7 +287,7 @@ class MessagesCacheTest {
|
||||
for (int i = 0; i < messageCount; i++) {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||
|
||||
expectedMessages.add(message);
|
||||
}
|
||||
@@ -295,7 +295,7 @@ class MessagesCacheTest {
|
||||
final UUID ephemeralMessageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true)
|
||||
.toBuilder().setEphemeral(true).build();
|
||||
messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage);
|
||||
messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage).join();
|
||||
|
||||
final Clock cacheClock;
|
||||
if (expectStale) {
|
||||
@@ -352,7 +352,7 @@ class MessagesCacheTest {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,7 +372,7 @@ class MessagesCacheTest {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message);
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,7 +404,7 @@ class MessagesCacheTest {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
|
||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
generateRandomMessage(messageGuid, sealedSender));
|
||||
generateRandomMessage(messageGuid, sealedSender)).join();
|
||||
final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID);
|
||||
|
||||
assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
|
||||
@@ -427,7 +427,7 @@ class MessagesCacheTest {
|
||||
|
||||
final byte[] sharedMrmDataKey;
|
||||
if (sharedMrmKeyPresent) {
|
||||
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
|
||||
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
|
||||
} else {
|
||||
sharedMrmDataKey = "{1}".getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
@@ -440,7 +440,7 @@ class MessagesCacheTest {
|
||||
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
|
||||
.clearContent()
|
||||
.build();
|
||||
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message);
|
||||
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message).join();
|
||||
|
||||
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
|
||||
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
|
||||
@@ -487,13 +487,13 @@ class MessagesCacheTest {
|
||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid,
|
||||
new AciServiceIdentifier(destinationUuid), true);
|
||||
|
||||
messagesCache.insert(messageGuid, destinationUuid, deviceId, message);
|
||||
messagesCache.insert(messageGuid, destinationUuid, deviceId, message).join();
|
||||
|
||||
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);
|
||||
|
||||
final byte[] sharedMrmDataKey;
|
||||
if (sharedMrmKeyPresent) {
|
||||
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
|
||||
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
|
||||
} else {
|
||||
sharedMrmDataKey = new byte[]{1};
|
||||
}
|
||||
@@ -505,7 +505,7 @@ class MessagesCacheTest {
|
||||
.clearContent()
|
||||
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
|
||||
.build();
|
||||
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage);
|
||||
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage).join();
|
||||
|
||||
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);
|
||||
|
||||
|
||||
@@ -7,22 +7,42 @@ package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyByte;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.protocol.InvalidVersionException;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
|
||||
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
class MessagesManagerTest {
|
||||
@@ -31,8 +51,15 @@ class MessagesManagerTest {
|
||||
private final MessagesCache messagesCache = mock(MessagesCache.class);
|
||||
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
|
||||
|
||||
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
|
||||
|
||||
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
|
||||
reportMessageManager, Executors.newSingleThreadExecutor());
|
||||
reportMessageManager, Executors.newSingleThreadExecutor(), CLOCK);
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
when(messagesCache.insert(any(), any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
void insert() {
|
||||
@@ -43,7 +70,7 @@ class MessagesManagerTest {
|
||||
|
||||
final UUID destinationUuid = UUID.randomUUID();
|
||||
|
||||
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, message);
|
||||
messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, message));
|
||||
|
||||
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
|
||||
|
||||
@@ -51,11 +78,113 @@ class MessagesManagerTest {
|
||||
.setSourceServiceId(destinationUuid.toString())
|
||||
.build();
|
||||
|
||||
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage);
|
||||
messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, syncMessage));
|
||||
|
||||
verifyNoMoreInteractions(reportMessageManager);
|
||||
}
|
||||
|
||||
@Test
|
||||
void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException {
|
||||
final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||
final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
|
||||
final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||
final ServiceIdentifier unresolvedAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||
|
||||
final Account singleDeviceAccount = mock(Account.class);
|
||||
final Account multiDeviceAccount = mock(Account.class);
|
||||
|
||||
when(singleDeviceAccount.getIdentifier(IdentityType.ACI))
|
||||
.thenReturn(singleDeviceAccountAciServiceIdentifier.uuid());
|
||||
|
||||
when(multiDeviceAccount.getIdentifier(IdentityType.ACI))
|
||||
.thenReturn(multiDeviceAccountAciServiceIdentifier.uuid());
|
||||
|
||||
final byte[] multiRecipientMessageBytes = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(
|
||||
new TestRecipient(singleDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 1, new byte[48]),
|
||||
new TestRecipient(multiDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 2, new byte[48]),
|
||||
new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]),
|
||||
new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]),
|
||||
new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48])
|
||||
));
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes);
|
||||
|
||||
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = new HashMap<>();
|
||||
|
||||
multiRecipientMessage.getRecipients().forEach(((serviceId, recipient) -> {
|
||||
if (serviceId.getRawUUID().equals(singleDeviceAccountAciServiceIdentifier.uuid()) ||
|
||||
serviceId.getRawUUID().equals(singleDeviceAccountPniServiceIdentifier.uuid())) {
|
||||
resolvedRecipients.put(recipient, singleDeviceAccount);
|
||||
} else if (serviceId.getRawUUID().equals(multiDeviceAccountAciServiceIdentifier.uuid())) {
|
||||
resolvedRecipients.put(recipient, multiDeviceAccount);
|
||||
}
|
||||
}));
|
||||
|
||||
final Map<Account, Map<Byte, Boolean>> expectedPresenceByAccountAndDeviceId = Map.of(
|
||||
singleDeviceAccount, Map.of(Device.PRIMARY_ID, true),
|
||||
multiDeviceAccount, Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
|
||||
);
|
||||
|
||||
final Map<UUID, Map<Byte, Boolean>> presenceByAccountIdentifierAndDeviceId = Map.of(
|
||||
singleDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, true),
|
||||
multiDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
|
||||
);
|
||||
|
||||
final byte[] sharedMrmKey = "shared-mrm-key".getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
when(messagesCache.insertSharedMultiRecipientMessagePayload(multiRecipientMessage))
|
||||
.thenReturn(CompletableFuture.completedFuture(sharedMrmKey));
|
||||
|
||||
when(messagesCache.insert(any(), any(), anyByte(), any()))
|
||||
.thenAnswer(invocation -> {
|
||||
final UUID accountIdentifier = invocation.getArgument(1);
|
||||
final byte deviceId = invocation.getArgument(2);
|
||||
|
||||
return CompletableFuture.completedFuture(
|
||||
presenceByAccountIdentifierAndDeviceId.getOrDefault(accountIdentifier, Collections.emptyMap())
|
||||
.getOrDefault(deviceId, false));
|
||||
});
|
||||
|
||||
final long clientTimestamp = System.currentTimeMillis();
|
||||
final boolean isStory = ThreadLocalRandom.current().nextBoolean();
|
||||
final boolean isEphemeral = ThreadLocalRandom.current().nextBoolean();
|
||||
final boolean isUrgent = ThreadLocalRandom.current().nextBoolean();
|
||||
|
||||
final Envelope prototypeExpectedMessage = Envelope.newBuilder()
|
||||
.setType(Envelope.Type.UNIDENTIFIED_SENDER)
|
||||
.setClientTimestamp(clientTimestamp)
|
||||
.setServerTimestamp(CLOCK.millis())
|
||||
.setStory(isStory)
|
||||
.setEphemeral(isEphemeral)
|
||||
.setUrgent(isUrgent)
|
||||
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
|
||||
.build();
|
||||
|
||||
assertEquals(expectedPresenceByAccountAndDeviceId,
|
||||
messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join());
|
||||
|
||||
verify(messagesCache).insert(any(),
|
||||
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
|
||||
eq(Device.PRIMARY_ID),
|
||||
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
|
||||
|
||||
verify(messagesCache).insert(any(),
|
||||
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
|
||||
eq(Device.PRIMARY_ID),
|
||||
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build()));
|
||||
|
||||
verify(messagesCache).insert(any(),
|
||||
eq(multiDeviceAccountAciServiceIdentifier.uuid()),
|
||||
eq((byte) (Device.PRIMARY_ID + 1)),
|
||||
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
|
||||
|
||||
verify(messagesCache, never()).insert(any(),
|
||||
eq(unresolvedAccountAciServiceIdentifier.uuid()),
|
||||
anyByte(),
|
||||
any());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource({
|
||||
"false, false, false",
|
||||
|
||||
@@ -29,6 +29,7 @@ class ReportMessageDynamoDbTest {
|
||||
void setUp() {
|
||||
this.reportMessageDynamoDb = new ReportMessageDynamoDb(
|
||||
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
|
||||
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
|
||||
Tables.REPORT_MESSAGES.tableName(),
|
||||
Duration.ofDays(1));
|
||||
}
|
||||
@@ -44,8 +45,8 @@ class ReportMessageDynamoDbTest {
|
||||
() -> assertFalse(reportMessageDynamoDb.remove(hash2))
|
||||
);
|
||||
|
||||
reportMessageDynamoDb.store(hash1);
|
||||
reportMessageDynamoDb.store(hash2);
|
||||
reportMessageDynamoDb.store(hash1).join();
|
||||
reportMessageDynamoDb.store(hash2).join();
|
||||
|
||||
assertAll("both hashes should be found",
|
||||
() -> assertTrue(reportMessageDynamoDb.remove(hash1)),
|
||||
|
||||
@@ -18,6 +18,7 @@ import static org.mockito.Mockito.when;
|
||||
import java.time.Duration;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
@@ -68,8 +69,8 @@ class ReportMessageManagerTest {
|
||||
|
||||
verify(reportMessageDynamoDb).store(any());
|
||||
|
||||
doThrow(RuntimeException.class)
|
||||
.when(reportMessageDynamoDb).store(any());
|
||||
when(reportMessageDynamoDb.store(any()))
|
||||
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
|
||||
|
||||
assertDoesNotThrow(() -> reportMessageManager.store(sourceAci.toString(), messageGuid));
|
||||
}
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.tests.util;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
|
||||
public class MultiRecipientMessageHelper {
|
||||
|
||||
private MultiRecipientMessageHelper() {
|
||||
}
|
||||
|
||||
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients) {
|
||||
return generateMultiRecipientMessage(recipients, 32);
|
||||
}
|
||||
|
||||
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients, final int sharedPayloadSize) {
|
||||
if (sharedPayloadSize < 32) {
|
||||
throw new IllegalArgumentException("Shared payload size must be at least 32 bytes");
|
||||
}
|
||||
|
||||
final ByteBuffer buffer = ByteBuffer.allocate(payloadSize(recipients, sharedPayloadSize));
|
||||
|
||||
// first write the header
|
||||
buffer.put((byte) 0x23); // version byte
|
||||
|
||||
// count varint
|
||||
writeVarint(buffer, recipients.size());
|
||||
|
||||
recipients.forEach(recipient -> {
|
||||
buffer.put(recipient.uuid().toFixedWidthByteArray());
|
||||
|
||||
assert recipient.deviceIds().length == recipient.registrationIds().length;
|
||||
|
||||
for (int i = 0; i < recipient.deviceIds().length; i++) {
|
||||
final int hasMore = i == recipient.deviceIds().length - 1 ? 0 : 0x8000;
|
||||
buffer.put(recipient.deviceIds()[i]); // device id (1 byte)
|
||||
buffer.putShort((short) (recipient.registrationIds()[i] | hasMore)); // registration id (2 bytes)
|
||||
}
|
||||
|
||||
buffer.put(recipient.perRecipientKeyMaterial()); // key material (48 bytes)
|
||||
});
|
||||
|
||||
// now write the actual message body (empty for now)
|
||||
writeVarint(buffer, sharedPayloadSize);
|
||||
buffer.put(new byte[sharedPayloadSize]);
|
||||
|
||||
return buffer.array();
|
||||
}
|
||||
|
||||
private static void writeVarint(final ByteBuffer buffer, long n) {
|
||||
if (n < 0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
while (n >= 0x80) {
|
||||
buffer.put ((byte) (n & 0x7F | 0x80));
|
||||
n >>= 7;
|
||||
}
|
||||
buffer.put((byte) (n & 0x7F));
|
||||
}
|
||||
|
||||
private static int payloadSize(final List<TestRecipient> recipients, final int sharedPayloadSize) {
|
||||
final int fixedBytesPerRecipient = 17 // Service identifier length
|
||||
+ 48; // Per-recipient key material
|
||||
|
||||
final int bytesForDevices = 3 * recipients.stream()
|
||||
.mapToInt(recipient -> recipient.deviceIds().length)
|
||||
.sum();
|
||||
|
||||
return 1 // Version byte
|
||||
+ varintLength(recipients.size())
|
||||
+ (recipients.size() * fixedBytesPerRecipient)
|
||||
+ bytesForDevices
|
||||
+ varintLength(sharedPayloadSize)
|
||||
+ sharedPayloadSize;
|
||||
}
|
||||
|
||||
private static int varintLength(long n) {
|
||||
int length = 0;
|
||||
|
||||
while (n >= 0x80) {
|
||||
length += 1;
|
||||
n >>= 7;
|
||||
}
|
||||
|
||||
return length + 1;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.tests.util;
|
||||
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
|
||||
public record TestRecipient(ServiceIdentifier uuid,
|
||||
byte[] deviceIds,
|
||||
int[] registrationIds,
|
||||
byte[] perRecipientKeyMaterial) {
|
||||
|
||||
public TestRecipient(ServiceIdentifier uuid,
|
||||
byte deviceId,
|
||||
int registrationId,
|
||||
byte[] perRecipientKeyMaterial) {
|
||||
|
||||
this(uuid, new byte[]{deviceId}, new int[]{registrationId}, perRecipientKeyMaterial);
|
||||
}
|
||||
}
|
||||
@@ -132,7 +132,7 @@ class WebSocketConnectionIntegrationTest {
|
||||
void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount) {
|
||||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||
mock(ReceiptSender.class),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
|
||||
new MessageMetrics(),
|
||||
mock(PushNotificationManager.class),
|
||||
mock(PushNotificationScheduler.class),
|
||||
@@ -164,7 +164,7 @@ class WebSocketConnectionIntegrationTest {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
||||
|
||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
|
||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
|
||||
expectedMessages.add(envelope);
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ class WebSocketConnectionIntegrationTest {
|
||||
void testProcessStoredMessagesClientClosed() {
|
||||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||
mock(ReceiptSender.class),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
|
||||
new MessageMetrics(),
|
||||
mock(PushNotificationManager.class),
|
||||
mock(PushNotificationScheduler.class),
|
||||
@@ -253,7 +253,7 @@ class WebSocketConnectionIntegrationTest {
|
||||
for (int i = 0; i < cachedMessageCount; i++) {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
|
||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
|
||||
|
||||
expectedMessages.add(envelope);
|
||||
}
|
||||
@@ -289,7 +289,7 @@ class WebSocketConnectionIntegrationTest {
|
||||
void testProcessStoredMessagesSendFutureTimeout() {
|
||||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||
mock(ReceiptSender.class),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
|
||||
new MessageMetrics(),
|
||||
mock(PushNotificationManager.class),
|
||||
mock(PushNotificationScheduler.class),
|
||||
@@ -323,7 +323,7 @@ class WebSocketConnectionIntegrationTest {
|
||||
for (int i = 0; i < cachedMessageCount; i++) {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
|
||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
|
||||
|
||||
expectedMessages.add(envelope);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user