Multi-recipient message views

This adds support for storing multi-recipient message payloads and recipient views in Redis, and only fanning out on delivery or persistence. Phase 1: confirm storage and retrieval correctness.
This commit is contained in:
Chris Eager
2024-09-04 13:58:20 -05:00
committed by GitHub
parent d78c8370b6
commit 11601fd091
50 changed files with 1544 additions and 328 deletions

View File

@@ -87,6 +87,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
@@ -121,6 +122,7 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.RemovedMessage;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@@ -251,6 +253,7 @@ class MessageControllerTest {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getInboundMessageByteLimitConfiguration()).thenReturn(inboundMessageByteLimitConfiguration);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
@@ -311,7 +314,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid());
assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice());
assertTrue(captor.getValue().getUrgent());
}
@@ -353,7 +356,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid());
assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice());
assertFalse(captor.getValue().getUrgent());
}
@@ -375,7 +378,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertTrue(captor.getValue().hasSourceUuid());
assertTrue(captor.getValue().hasSourceServiceId());
assertTrue(captor.getValue().hasSourceDevice());
}
}
@@ -410,7 +413,7 @@ class MessageControllerTest {
ArgumentCaptor<Envelope> captor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false));
assertFalse(captor.getValue().hasSourceUuid());
assertFalse(captor.getValue().hasSourceServiceId());
assertFalse(captor.getValue().hasSourceDevice());
}
}
@@ -444,7 +447,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) {
verify(messageSender).sendMessage(
any(Account.class), any(Device.class), argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
any(Account.class), any(Device.class), argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(false));
} else {
verifyNoMoreInteractions(messageSender);
@@ -732,23 +735,27 @@ class MessageControllerTest {
@Test
void testDeleteMessages() {
long timestamp = System.currentTimeMillis();
long clientTimestamp = System.currentTimeMillis();
UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID();
final long serverTimestamp = 0;
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid1, null))
.thenReturn(
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))));
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(
new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)),
new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid1, serverTimestamp, clientTimestamp,
Envelope.Type.CIPHERTEXT))));
UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid2, null))
.thenReturn(
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0))));
CompletableFutureTestUtil.almostCompletedFuture(Optional.of(
new RemovedMessage(Optional.of(new AciServiceIdentifier(sourceUuid)),
new AciServiceIdentifier(AuthHelper.VALID_UUID), uuid2, serverTimestamp, clientTimestamp,
Envelope.Type.SERVER_DELIVERY_RECEIPT))));
UUID uuid3 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE, uuid3, null))
@@ -766,7 +773,7 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1),
eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp));
eq(new AciServiceIdentifier(sourceUuid)), eq(clientTimestamp));
}
try (final Response response = resources.getJerseyTest()
@@ -1068,9 +1075,16 @@ class MessageControllerTest {
}
private record Recipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
Byte[] deviceId,
Integer[] registrationId,
byte[] perRecipientKeyMaterial) {
Recipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
this(uuid, new Byte[]{deviceId}, new Integer[]{registrationId}, perRecipientKeyMaterial);
}
}
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r,
@@ -1081,8 +1095,13 @@ class MessageControllerTest {
bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
}
bb.put(r.deviceId()); // device id (1 byte)
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
assert (r.deviceId.length == r.registrationId.length);
for (int i = 0; i < r.deviceId.length; i++) {
final int hasMore = i == r.deviceId.length - 1 ? 0 : 0x8000;
bb.put(r.deviceId()[i]); // device id (1 byte)
bb.putShort((short) (r.registrationId()[i] | hasMore)); // registration id (2 bytes)
}
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
@@ -1157,7 +1176,7 @@ class MessageControllerTest {
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.put(entity)) {
assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200)));
@@ -1206,7 +1225,7 @@ class MessageControllerTest {
.queryParam("story", isStory)
.queryParam("urgent", urgent)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader)
.put(entity)) {
@@ -1216,7 +1235,7 @@ class MessageControllerTest {
.sendMessage(
any(),
any(),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(true));
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
@@ -1384,7 +1403,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@@ -1395,7 +1414,7 @@ class MessageControllerTest {
.sendMessage(
any(),
any(),
argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()),
eq(true));
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
@@ -1423,7 +1442,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@@ -1454,7 +1473,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader(
serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z")))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) {
@@ -1620,7 +1639,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
@@ -1663,7 +1682,7 @@ class MessageControllerTest {
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
@@ -1702,7 +1721,7 @@ class MessageControllerTest {
.queryParam("story", true)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HttpHeaders.USER_AGENT, "test")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class)))
@@ -1730,14 +1749,14 @@ class MessageControllerTest {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type))
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString())
.setDestinationServiceId(destinationUuid.toString())
.setStory(story)
.setServerGuid(guid.toString());
if (sourceUuid != null) {
builder.setSourceUuid(sourceUuid.toString());
builder.setSourceServiceId(sourceUuid.toString());
builder.setSourceDevice(sourceDevice);
}

View File

@@ -104,7 +104,7 @@ class MessageMetricsTest {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
if (destinationIdentifier != null) {
builder.setDestinationUuid(destinationIdentifier.toServiceIdentifierString());
builder.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString());
}
return builder.build();

View File

@@ -151,7 +151,7 @@ class MessageSenderTest {
private MessageProtos.Envelope generateRandomMessage() {
return MessageProtos.Envelope.newBuilder()
.setTimestamp(System.currentTimeMillis())
.setClientTimestamp(System.currentTimeMillis())
.setServerTimestamp(System.currentTimeMillis())
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)

View File

@@ -54,8 +54,8 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

View File

@@ -160,8 +160,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
}
@@ -208,8 +208,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
}
@@ -254,8 +254,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@@ -296,8 +296,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@@ -340,8 +340,8 @@ public class ChangeNumberManagerTest {
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}

View File

@@ -81,7 +81,7 @@ class MessagePersisterIntegrationTest {
notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService,
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
@@ -185,12 +185,12 @@ class MessagePersisterIntegrationTest {
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long serverTimestamp) {
return MessageProtos.Envelope.newBuilder()
.setTimestamp(serverTimestamp * 2) // client timestamp may not be accurate
.setClientTimestamp(serverTimestamp * 2) // client timestamp may not be accurate
.setServerTimestamp(serverTimestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setDestinationServiceId(UUID.randomUUID().toString())
.build();
}
}

View File

@@ -40,6 +40,7 @@ import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
@@ -48,12 +49,11 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class MessagePersisterTest {
@RegisterExtension
@@ -104,7 +104,7 @@ class MessagePersisterTest {
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, clientPresenceManager,
keysManager, dynamicConfigurationManager, PERSIST_DELAY, 1);
@@ -356,7 +356,8 @@ class MessagePersisterTest {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setDestinationServiceId(accountUuid.toString())
.setClientTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)

View File

@@ -0,0 +1,74 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import io.lettuce.core.RedisCommandExecutionException;
import io.lettuce.core.ScriptOutputType;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheGetItemsScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheGetItemsScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final String serverGuid = UUID.randomUUID().toString();
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(serverGuid)
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1)
.block(Duration.ofSeconds(1));
assertNotNull(messageAndScores);
assertEquals(2, messageAndScores.size());
final MessageProtos.Envelope resultEnvelope = MessageProtos.Envelope.parseFrom(
messageAndScores.getFirst());
assertEquals(serverGuid, resultEnvelope.getServerGuid());
}
@Test
void testCacheGetItemsInvalidParameter() throws Exception {
final ClusterLuaScript getItemsScript = ClusterLuaScript.fromResource(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
"lua/get_items.lua", ScriptOutputType.OBJECT);
final byte[] fakeKey = new byte[]{1};
final Exception e = assertThrows(RedisCommandExecutionException.class,
() -> getItemsScript.executeBinaryReactive(List.of(fakeKey, fakeKey),
List.of("1".getBytes(StandardCharsets.UTF_8)))
.next()
.block(Duration.ofSeconds(1)));
assertEquals("ERR afterMessageId is required", e.getMessage());
}
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Instant;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheInsertScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheInsertScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1));
final MessageProtos.Envelope envelope2 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
assertEquals(2, insertScript.execute(destinationUuid, deviceId, envelope2));
assertEquals(1, insertScript.execute(destinationUuid, deviceId, envelope1),
"Repeated with same guid should have same message ID");
}
}

View File

@@ -0,0 +1,74 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ParameterizedTest
@MethodSource
void testInsert(final int count, final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().hlen(sharedMrmKey));
assertEquals(totalDevices + 1, hashFieldCount);
}
public static List<Arguments> testInsert() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Arguments> testCases = new ArrayList<>();
testCases.add(Arguments.of(1, singleAccount));
for (int j = 1000; j <= 30000; j += 1000) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
.mapToObj(v -> (byte) v)
.toList());
return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices);
})
.collect(Collectors.toMap(Pair::first, Pair::second));
testCases.add(Arguments.of(j, manyAccounts));
}
return testCases;
}
}

View File

@@ -0,0 +1,52 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheRemoveByGuidScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheRemoveByGuid() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final UUID serverGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(serverGuid.toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> removedMessages = removeByGuidScript.execute(destinationUuid, deviceId,
List.of(serverGuid)).get(1, TimeUnit.SECONDS);
assertEquals(1, removedMessages.size());
final MessageProtos.Envelope resultMessage = MessageProtos.Envelope.parseFrom(
removedMessages.getFirst());
assertEquals(serverGuid, UUID.fromString(resultMessage.getServerGuid()));
}
}

View File

@@ -0,0 +1,50 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheRemoveQueueScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
void testCacheRemoveQueueScript() throws Exception {
final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final MessageProtos.Envelope envelope1 = MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build();
insertScript.execute(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> messagesToCheckForMrmKeys = removeScript.execute(destinationUuid, deviceId,
Collections.emptyList())
.block(Duration.ofSeconds(1));
assertEquals(1, messagesToCheckForMrmKeys.size());
}
}

View File

@@ -0,0 +1,124 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.lettuce.core.cluster.SlotHash;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.publisher.Flux;
import reactor.util.function.Tuples;
class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ParameterizedTest
@MethodSource
void testUpdateSingleKey(final Map<AciServiceIdentifier, List<Byte>> destinations) throws Exception {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations));
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(destinations.entrySet())
.flatMap(e -> Flux.fromStream(e.getValue().stream().map(deviceId -> Tuples.of(e.getKey(), deviceId))))
.flatMap(aciServiceIdentifierByteTuple -> removeRecipientViewFromMrmDataScript.execute(List.of(sharedMrmKey),
aciServiceIdentifierByteTuple.getT1(), aciServiceIdentifierByteTuple.getT2()))
.reduce(Long::sum)
.block(Duration.ofSeconds(35)));
assertEquals(1, keysRemoved);
final long keyExists = REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmKey));
assertEquals(0, keyExists);
}
public static List<Map<AciServiceIdentifier, List<Byte>>> testUpdateSingleKey() {
final Map<AciServiceIdentifier, List<Byte>> singleAccount = Map.of(
new AciServiceIdentifier(UUID.randomUUID()), List.of((byte) 1, (byte) 2));
final List<Map<AciServiceIdentifier, List<Byte>>> testCases = new ArrayList<>();
testCases.add(singleAccount);
// Generate a more, from smallish to very large
for (int j = 1000; j <= 81000; j *= 3) {
final Map<Integer, List<Byte>> deviceLists = new HashMap<>();
final Map<AciServiceIdentifier, List<Byte>> manyAccounts = IntStream.range(0, j)
.mapToObj(i -> {
final int deviceCount = 1 + i % 5;
final List<Byte> devices = deviceLists.computeIfAbsent(deviceCount, count -> IntStream.rangeClosed(1, count)
.mapToObj(v -> (byte) v)
.toList());
return new Pair<>(new AciServiceIdentifier(UUID.randomUUID()), devices);
})
.collect(Collectors.toMap(Pair::first, Pair::second));
testCases.add(manyAccounts);
}
return testCases;
}
@ParameterizedTest
@ValueSource(ints = {1, 10, 100, 1000, 10000})
void testUpdateManyKeys(int keyCount) throws Exception {
final List<byte[]> sharedMrmKeys = new ArrayList<>(keyCount);
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final byte deviceId = 1;
for (int i = 0; i < keyCount; i++) {
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript = new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(aciServiceIdentifier, deviceId));
sharedMrmKeys.add(sharedMrmKey);
}
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final long keysRemoved = Objects.requireNonNull(Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(keys -> removeRecipientViewFromMrmDataScript.execute(keys, aciServiceIdentifier, deviceId))
.reduce(Long::sum)
.block(Duration.ofSeconds(5)));
assertEquals(sharedMrmKeys.size(), keysRemoved);
}
}

View File

@@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
@@ -25,6 +26,8 @@ import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.RedisCommand;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
@@ -32,9 +35,12 @@ import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
@@ -42,11 +48,13 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -57,7 +65,12 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
@@ -83,6 +96,8 @@ class MessagesCacheTest {
private Scheduler messageDeliveryScheduler;
private MessagesCache messagesCache;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final byte DESTINATION_DEVICE_ID = 7;
@@ -95,11 +110,16 @@ class MessagesCacheTest {
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
});
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagesConfiguration()).thenReturn(new DynamicMessagesConfiguration(true, true));
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
sharedExecutorService = Executors.newSingleThreadExecutor();
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesCache.start();
}
@@ -148,10 +168,10 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<MessageProtos.Envelope> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
assertEquals(Optional.of(message), maybeRemovedMessage);
assertEquals(Optional.of(RemovedMessage.fromEnvelope(message)), maybeRemovedMessage);
}
@ParameterizedTest
@@ -181,11 +201,11 @@ class MessagesCacheTest {
message);
}
final List<MessageProtos.Envelope> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid()))
.collect(Collectors.toList())).get(5, TimeUnit.SECONDS);
assertEquals(messagesToRemove, removedMessages);
assertEquals(messagesToRemove.stream().map(RemovedMessage::fromEnvelope).toList(), removedMessages);
assertEquals(messagesToPreserve,
messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@@ -283,7 +303,8 @@ class MessagesCacheTest {
}
final MessagesCache messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock);
sharedExecutorService, messageDeliveryScheduler, sharedExecutorService, cacheClock,
dynamicConfigurationManager);
final List<MessageProtos.Envelope> actualMessages = Flux.from(
messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID))
@@ -320,7 +341,7 @@ class MessagesCacheTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
final int messageCount = 1000;
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
@@ -340,7 +361,7 @@ class MessagesCacheTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
final int messageCount = 1000;
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
@@ -542,6 +563,57 @@ class MessagesCacheTest {
});
}
@Test
void testMultiRecipientMessage() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final UUID mrmGuid = UUID.randomUUID();
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(
new AciServiceIdentifier(destinationUuid), deviceId);
final byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrmGuid, mrm);
final UUID guid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(guid, true)
.toBuilder()
// clear some things added by the helper
.clearServerGuid()
// mrm views phase 1: messages have content
.setContent(
ByteString.copyFrom(mrm.messageForRecipient(mrm.getRecipients().get(new ServiceId.Aci(destinationUuid)))))
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
messagesCache.insert(guid, destinationUuid, deviceId, message);
assertEquals(1L, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid))));
final List<MessageProtos.Envelope> messages = get(destinationUuid, deviceId, 1);
assertEquals(1, messages.size());
assertEquals(guid, UUID.fromString(messages.getFirst().getServerGuid()));
assertFalse(messages.getFirst().hasSharedMrmKey());
final SealedSenderMultiRecipientMessage.Recipient recipient = mrm.getRecipients()
.get(new ServiceId.Aci(destinationUuid));
assertArrayEquals(mrm.messageForRecipient(recipient), messages.getFirst().getContent().toByteArray());
final Optional<RemovedMessage> removedMessage = messagesCache.remove(destinationUuid, deviceId, guid)
.join();
assertTrue(removedMessage.isPresent());
assertEquals(guid, UUID.fromString(removedMessage.get().serverGuid().toString()));
assertTrue(get(destinationUuid, deviceId, 1).isEmpty());
// updating the shared MRM data is purely async, so we just wait for it
assertTimeoutPreemptively(Duration.ofSeconds(1), () -> {
boolean exists;
do {
exists = 1 == REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(MessagesCache.getSharedMrmKey(mrmGuid)));
} while (exists);
});
}
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
final int messageCount) {
return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId))
@@ -573,7 +645,7 @@ class MessagesCacheTest {
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
messagesCache = new MessagesCache(mockCluster, mock(ExecutorService.class), messageDeliveryScheduler,
Executors.newSingleThreadExecutor(), Clock.systemUTC());
Executors.newSingleThreadExecutor(), Clock.systemUTC(), mock(DynamicConfigurationManager.class));
}
@AfterEach
@@ -755,18 +827,85 @@ class MessagesCacheTest {
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender,
final long timestamp) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString());
.setDestinationServiceId(UUID.randomUUID().toString());
if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1)
.setSourceUuid(UUID.randomUUID().toString());
.setSourceServiceId(UUID.randomUUID().toString());
}
return envelopeBuilder.build();
}
static SealedSenderMultiRecipientMessage generateRandomMrmMessage(
Map<AciServiceIdentifier, List<Byte>> destinations) {
try {
final ByteBuffer prefix = ByteBuffer.allocate(7);
prefix.put((byte) 0x23); // version
writeVarint(prefix, destinations.size()); // recipient count
prefix.flip();
List<ByteBuffer> recipients = new ArrayList<>(destinations.size());
for (Map.Entry<AciServiceIdentifier, List<Byte>> aciServiceIdentifierAndDeviceIds : destinations.entrySet()) {
final AciServiceIdentifier destination = aciServiceIdentifierAndDeviceIds.getKey();
final List<Byte> deviceIds = aciServiceIdentifierAndDeviceIds.getValue();
assert deviceIds.size() < 255;
final ByteBuffer recipient = ByteBuffer.allocate(17 + 3 * deviceIds.size() + 48);
recipient.put(destination.toFixedWidthByteArray());
for (int i = 0; i < deviceIds.size(); i++) {
final int hasMore = i == deviceIds.size() - 1 ? 0x0000 : 0x8000;
recipient.put(new byte[]{deviceIds.get(i)}); // device ID
recipient.putShort((short) ((100 + deviceIds.get(i)) | hasMore)); // registration ID
}
final byte[] keyMaterial = new byte[48];
ThreadLocalRandom.current().nextBytes(keyMaterial);
recipient.put(keyMaterial);
recipients.add(recipient);
}
final byte[] commonPayload = new byte[64];
ThreadLocalRandom.current().nextBytes(commonPayload);
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
baos.write(prefix.array(), 0, prefix.limit());
for (ByteBuffer recipient : recipients) {
baos.write(recipient.array());
}
baos.write(commonPayload);
return SealedSenderMultiRecipientMessage.parse(baos.toByteArray());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
static SealedSenderMultiRecipientMessage generateRandomMrmMessage(AciServiceIdentifier destination,
byte... deviceIds) {
final Map<AciServiceIdentifier, List<Byte>> destinations = new HashMap<>();
destinations.put(destination, Arrays.asList(ArrayUtils.toObject(deviceIds)));
return generateRandomMrmMessage(destinations);
}
private static void writeVarint(ByteBuffer bb, long n) {
while (n >= 0x80) {
bb.put((byte) (n & 0x7F | 0x80));
n = n >> 7;
}
bb.put((byte) (n & 0x7F));
}
}

View File

@@ -6,8 +6,6 @@
package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.time.Duration;
@@ -31,7 +29,6 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.MessageHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
@@ -47,31 +44,31 @@ class MessagesDynamoDbTest {
final long serverTimestamp = System.currentTimeMillis();
MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER);
builder.setTimestamp(123456789L);
builder.setClientTimestamp(123456789L);
builder.setContent(ByteString.copyFrom(new byte[]{(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}));
builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp);
builder.setDestinationUuid(UUID.randomUUID().toString());
builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE1 = builder.build();
builder.setType(MessageProtos.Envelope.Type.CIPHERTEXT);
builder.setSourceUuid(UUID.randomUUID().toString());
builder.setSourceServiceId(UUID.randomUUID().toString());
builder.setSourceDevice(1);
builder.setContent(ByteString.copyFromUtf8("MOO"));
builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp + 1);
builder.setDestinationUuid(UUID.randomUUID().toString());
builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE2 = builder.build();
builder.setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER);
builder.clearSourceUuid();
builder.clearSourceDevice();
builder.clearSourceDevice();
builder.setContent(ByteString.copyFromUtf8("COW"));
builder.setServerGuid(UUID.randomUUID().toString());
builder.setServerTimestamp(serverTimestamp); // Test same millisecond arrival for two different messages
builder.setDestinationUuid(UUID.randomUUID().toString());
builder.setDestinationServiceId(UUID.randomUUID().toString());
MESSAGE3 = builder.build();
}

View File

@@ -35,7 +35,7 @@ class MessagesManagerTest {
void insert() {
final UUID sourceAci = UUID.randomUUID();
final Envelope message = Envelope.newBuilder()
.setSourceUuid(sourceAci.toString())
.setSourceServiceId(sourceAci.toString())
.build();
final UUID destinationUuid = UUID.randomUUID();
@@ -45,7 +45,7 @@ class MessagesManagerTest {
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
final Envelope syncMessage = Envelope.newBuilder(message)
.setSourceUuid(destinationUuid.toString())
.setSourceServiceId(destinationUuid.toString())
.build();
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage);

View File

@@ -17,11 +17,11 @@ public class MessageHelper {
return MessageProtos.Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(0)
.setSourceUuid(senderUuid.toString())
.setSourceServiceId(senderUuid.toString())
.setSourceDevice(senderDeviceId)
.setDestinationUuid(destinationUuid.toString())
.setDestinationServiceId(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build();
}

View File

@@ -44,6 +44,7 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
@@ -55,6 +56,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
@@ -85,6 +87,8 @@ class WebSocketConnectionIntegrationTest {
private Scheduler messageDeliveryScheduler;
private ClientReleaseManager clientReleaseManager;
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private long serialTimestamp = System.currentTimeMillis();
@BeforeEach
@@ -92,8 +96,10 @@ class WebSocketConnectionIntegrationTest {
sharedExecutorService = Executors.newSingleThreadExecutor();
scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService,
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC());
messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC(), dynamicConfigurationManager);
messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.MESSAGES.tableName(), Duration.ofDays(7),
sharedExecutorService);
@@ -381,12 +387,12 @@ class WebSocketConnectionIntegrationTest {
final long timestamp = serialTimestamp++;
return MessageProtos.Envelope.newBuilder()
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setDestinationServiceId(UUID.randomUUID().toString())
.build();
}

View File

@@ -48,7 +48,6 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -297,19 +296,19 @@ class WebSocketConnectionTest {
final Envelope firstMessage = Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(UUID.randomUUID().toString())
.setDestinationUuid(accountUuid.toString())
.setSourceServiceId(UUID.randomUUID().toString())
.setDestinationServiceId(accountUuid.toString())
.setUpdatedPni(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis())
.setClientTimestamp(System.currentTimeMillis())
.setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT)
.build();
final Envelope secondMessage = Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(senderTwoUuid.toString())
.setDestinationUuid(accountUuid.toString())
.setTimestamp(System.currentTimeMillis())
.setSourceServiceId(senderTwoUuid.toString())
.setDestinationServiceId(accountUuid.toString())
.setClientTimestamp(System.currentTimeMillis())
.setSourceDevice(2)
.setType(Envelope.Type.CIPHERTEXT)
.build();
@@ -365,7 +364,7 @@ class WebSocketConnectionTest {
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getTimestamp()));
eq(secondMessage.getClientTimestamp()));
connection.stop();
verify(client).close(anyInt(), anyString());
@@ -616,10 +615,10 @@ class WebSocketConnectionTest {
final byte[] body = argument.get();
try {
final Envelope envelope = Envelope.parseFrom(body);
if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) {
if (!envelope.hasSourceServiceId() || envelope.getSourceServiceId().length() == 0) {
return false;
}
return envelope.getSourceUuid().equals(senderUuid.toString());
return envelope.getSourceServiceId().equals(senderUuid.toString());
} catch (InvalidProtocolBufferException e) {
return false;
}
@@ -627,7 +626,7 @@ class WebSocketConnectionTest {
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) {
private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class));
@@ -933,11 +932,11 @@ class WebSocketConnectionTest {
return Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setType(Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(0)
.setSourceUuid(senderUuid.toString())
.setSourceServiceId(senderUuid.toString())
.setSourceDevice(SOURCE_DEVICE_ID)
.setDestinationUuid(destinationUuid.toString())
.setDestinationServiceId(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build();
}