diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisher.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisher.java index 56e4ddfc8..85d8f053b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisher.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisher.java @@ -10,6 +10,7 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.UUIDUtil; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; @@ -25,6 +26,7 @@ import java.util.function.BiConsumer; class FoundationDbMessagePublisher { private final Database database; + private final MessageGuidCodec messageGuidCodec; /// The maximum number of messages we will fetch per range query operation to avoid excessive memory consumption private final int maxMessagesPerScan; /// The end key at which we stop reading messages. For finite publisher, this is just past the end-of-queue key at the @@ -88,15 +90,18 @@ class FoundationDbMessagePublisher { private static final Logger LOGGER = LoggerFactory.getLogger(FoundationDbMessagePublisher.class); FoundationDbMessagePublisher( - final KeySelector beginKeyInclusive, - final KeySelector endKeyExclusive, - final Database database, - final int maxMessagesPerScan, - @Nullable final byte[] messagesAvailableWatchKey, - @Nullable final BiConsumer stateChangeListener) { + final KeySelector beginKeyInclusive, + final KeySelector endKeyExclusive, + final Database database, + final MessageGuidCodec messageGuidCodec, + final int maxMessagesPerScan, + @Nullable final byte[] messagesAvailableWatchKey, + @Nullable final BiConsumer stateChangeListener) { + this.beginKeyCursor = beginKeyInclusive; this.endKeyExclusive = endKeyExclusive; this.database = database; + this.messageGuidCodec = messageGuidCodec; this.maxMessagesPerScan = maxMessagesPerScan; this.messagesAvailableWatchKey = messagesAvailableWatchKey; this.terminateOnQueueEmpty = messagesAvailableWatchKey == null; @@ -116,8 +121,16 @@ class FoundationDbMessagePublisher { final KeySelector beginKeyInclusive, final KeySelector endKeyExclusive, final Database database, + final MessageGuidCodec messageGuidCodec, final int maxMessagesPerScan) { - return new FoundationDbMessagePublisher(beginKeyInclusive, endKeyExclusive, database, maxMessagesPerScan, null, null); + + return new FoundationDbMessagePublisher(beginKeyInclusive, + endKeyExclusive, + database, + messageGuidCodec, + maxMessagesPerScan, + null, + null); } /// Creates a [FoundationDbMessagePublisher] that publishes a non-terminating stream of messages from a device queue. @@ -127,9 +140,17 @@ class FoundationDbMessagePublisher { final KeySelector beginKeyInclusive, final KeySelector endKeyExclusive, final Database database, + final MessageGuidCodec messageGuidCodec, final int maxMessagesPerScan, final byte[] messagesAvailableWatchKey) { - return new FoundationDbMessagePublisher(beginKeyInclusive, endKeyExclusive, database, maxMessagesPerScan, messagesAvailableWatchKey, null); + + return new FoundationDbMessagePublisher(beginKeyInclusive, + endKeyExclusive, + database, + messageGuidCodec, + maxMessagesPerScan, + messagesAvailableWatchKey, + null); } private synchronized void setState(final State newState, final Event event) { @@ -222,7 +243,7 @@ class FoundationDbMessagePublisher { /// /// @return a future of a list of [MessageStreamEntry] with a max size of [#maxMessagesPerScan] private CompletableFuture> getMessagesBatch() { - return database.runAsync(transaction -> getItemsInRange(transaction, beginKeyCursor, endKeyExclusive, maxMessagesPerScan) + return database.runAsync(transaction -> getItemsInRange(transaction, messageGuidCodec, beginKeyCursor, endKeyExclusive, maxMessagesPerScan) .thenApply(lastKeyReadAndItems -> { // Set our beginning key to just past the last key read so that we're ready for our next fetch lastKeyReadAndItems.first().ifPresent(lastKeyRead -> beginKeyCursor = KeySelector.firstGreaterThan(lastKeyRead)); @@ -248,6 +269,7 @@ class FoundationDbMessagePublisher { /// @return the last key read (if there were non-zero number of messages read) and the list of messages read private static CompletableFuture, List>> getItemsInRange( final Transaction transaction, + final MessageGuidCodec messageGuidCodec, final KeySelector beginInclusive, final KeySelector endExclusive, final int maxMessagesPerScan) { @@ -259,7 +281,10 @@ class FoundationDbMessagePublisher { final List messages = keyValues.stream() .map(keyValue -> { try { - return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue())); + return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue()) + .toBuilder() + .setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(FoundationDbMessageStore.getVersionstamp(keyValue.getKey())))) + .build()); } catch (final InvalidProtocolBufferException e) { throw new UncheckedIOException(e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index ce04355fb..132aac0b6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -40,6 +40,7 @@ import org.whispersystems.textsecuregcm.util.Conversions; public class FoundationDbMessageStore { private final Database[] databases; + private final VersionstampUUIDCipher versionstampUUIDCipher; private final Executor executor; private final Clock clock; @@ -60,8 +61,13 @@ public class FoundationDbMessageStore { public record InsertResult(Optional versionstamp, boolean present) { } - public FoundationDbMessageStore(final Database[] databases, final Executor executor, final Clock clock) { + public FoundationDbMessageStore(final Database[] databases, + final VersionstampUUIDCipher versionstampUUIDCipher, + final Executor executor, + final Clock clock) { + this.databases = databases; + this.versionstampUUIDCipher = versionstampUUIDCipher; this.executor = executor; this.clock = clock; } @@ -263,9 +269,14 @@ public class FoundationDbMessageStore { return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()), getMessagesAvailableWatchKey(aci), getShardForAci(aci), + new MessageGuidCodec(aci.uuid(), destinationDevice.getId(), versionstampUUIDCipher), maxMessagesPerScan); } + static Versionstamp getVersionstamp(final byte[] messageKey) { + return Tuple.fromBytes(messageKey).getVersionstamp(4); + } + @VisibleForTesting Database getShardForAci(final AciServiceIdentifier aci) { return databases[hashAciToShardNumber(aci)]; @@ -278,20 +289,20 @@ public class FoundationDbMessageStore { } @VisibleForTesting - Subspace getDeviceQueueSubspace(final AciServiceIdentifier aci, final byte deviceId) { + static Subspace getDeviceQueueSubspace(final AciServiceIdentifier aci, final byte deviceId) { return getDeviceSubspace(aci, deviceId).get("Q"); } - private Subspace getDeviceSubspace(final AciServiceIdentifier aci, final byte deviceId) { + private static Subspace getDeviceSubspace(final AciServiceIdentifier aci, final byte deviceId) { return getAccountSubspace(aci).get(deviceId); } - private Subspace getAccountSubspace(final AciServiceIdentifier aci) { + private static Subspace getAccountSubspace(final AciServiceIdentifier aci) { return MESSAGES_SUBSPACE.get(aci.uuid()); } @VisibleForTesting - byte[] getMessagesAvailableWatchKey(final AciServiceIdentifier aci) { + static byte[] getMessagesAvailableWatchKey(final AciServiceIdentifier aci) { return getAccountSubspace(aci).pack("l"); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java index 9b745f6a0..ff66d6836 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStream.java @@ -21,17 +21,23 @@ public class FoundationDbMessageStream implements MessageStream { private final Subspace deviceQueueSubspace; private final byte[] messagesAvailableWatchKey; private final Database database; + private final MessageGuidCodec messageGuidCodec; /// The maximum number of messages we will fetch per range query operation to avoid excessive memory consumption private final int maxMessagesPerScan; private final Flow.Publisher messageStreamPublisher; static final int DEFAULT_MAX_MESSAGES_PER_SCAN = 1024; - FoundationDbMessageStream(final Subspace deviceQueueSubspace, final byte[] messagesAvailableWatchKey, - final Database database, final int maxMessagesPerScan) { + FoundationDbMessageStream(final Subspace deviceQueueSubspace, + final byte[] messagesAvailableWatchKey, + final Database database, + final MessageGuidCodec messageGuidCodec, + final int maxMessagesPerScan) { + this.deviceQueueSubspace = deviceQueueSubspace; this.messagesAvailableWatchKey = messagesAvailableWatchKey; this.database = database; + this.messageGuidCodec = messageGuidCodec; this.maxMessagesPerScan = maxMessagesPerScan; this.messageStreamPublisher = JdkFlowAdapter.publisherToFlowPublisher(createMessagePublisher()); } @@ -59,13 +65,13 @@ public class FoundationDbMessageStream implements MessageStream { final Flux finitePublisher = maybeEndOfQueueKeyExclusive .map(endOfQueueKeyExclusive -> FoundationDbMessagePublisher.createFinitePublisher( KeySelector.firstGreaterOrEqual(deviceQueueSubspace.range().begin), - endOfQueueKeyExclusive, database, maxMessagesPerScan).getMessages()) + endOfQueueKeyExclusive, database, messageGuidCodec, maxMessagesPerScan).getMessages()) .orElseGet(Flux::empty); final KeySelector infinitePublisherBeginKey = maybeEndOfQueueKeyExclusive.orElseGet( () -> KeySelector.firstGreaterOrEqual(deviceQueueSubspace.range().begin)); final Flux infinitePublisher = FoundationDbMessagePublisher.createInfinitePublisher( infinitePublisherBeginKey, KeySelector.firstGreaterThan(deviceQueueSubspace.range().end), - database, maxMessagesPerScan, messagesAvailableWatchKey).getMessages(); + database, messageGuidCodec, maxMessagesPerScan, messagesAvailableWatchKey).getMessages(); return Flux.concat( finitePublisher, Mono.just(new MessageStreamEntry.QueueEmpty()), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/MessageGuidCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/MessageGuidCodec.java new file mode 100644 index 000000000..fa182af78 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/MessageGuidCodec.java @@ -0,0 +1,33 @@ +/* + * Copyright 2026 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage.foundationdb; + +import com.apple.foundationdb.tuple.Versionstamp; +import java.util.UUID; + +class MessageGuidCodec { + + private final UUID accountIdentifier; + private final byte deviceId; + private final VersionstampUUIDCipher versionstampUUIDCipher; + + MessageGuidCodec(final UUID accountIdentifier, + final byte deviceId, + final VersionstampUUIDCipher versionstampUUIDCipher) { + + this.accountIdentifier = accountIdentifier; + this.deviceId = deviceId; + this.versionstampUUIDCipher = versionstampUUIDCipher; + } + + public UUID encodeMessageGuid(final Versionstamp versionstamp) { + return versionstampUUIDCipher.encryptVersionstamp(versionstamp, accountIdentifier, deviceId); + } + + public Versionstamp decodeMessageGuid(final UUID messageGuid) { + return versionstampUUIDCipher.decryptVersionstamp(messageGuid, accountIdentifier, deviceId); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisherTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisherTest.java index 4a088e9b6..badde85af 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisherTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessagePublisherTest.java @@ -16,36 +16,58 @@ import com.apple.foundationdb.Range; import com.apple.foundationdb.StreamingMode; import com.apple.foundationdb.Transaction; import com.apple.foundationdb.async.AsyncIterable; +import java.nio.ByteBuffer; +import java.security.SecureRandom; import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Function; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.foundationdb.tuple.Versionstamp; +import com.google.protobuf.InvalidProtocolBufferException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; +import org.whispersystems.textsecuregcm.util.UUIDUtil; import reactor.test.StepVerifier; /// NOTE: most of the happy-path test cases are already covered in {@link FoundationDbMessageStoreTest}, this test /// mostly exercises edge-cases and error handling that are hard to test without mocks class FoundationDbMessagePublisherTest { - private final Range subspaceRange = new Range(new byte[]{(byte) 0}, new byte[]{(byte) 100}); - private final byte[] messagesAvailableWatchKey = new byte[]{(byte) 42}; - private Database database; + private MessageGuidCodec messageGuidCodec; private List stateTransitions; + private static final AciServiceIdentifier SERVICE_IDENTIFIER = new AciServiceIdentifier(UUID.randomUUID()); + + private static final Range SUBSPACE_RANGE = + FoundationDbMessageStore.getDeviceQueueSubspace(SERVICE_IDENTIFIER, Device.PRIMARY_ID).range(); + + private static final byte[] MESSAGES_AVAILABLE_WATCH_KEY = + FoundationDbMessageStore.getMessagesAvailableWatchKey(SERVICE_IDENTIFIER); + @BeforeEach void setUp() { database = mock(Database.class); stateTransitions = new ArrayList<>(); + + final byte[] messageGuidCodecKey = new byte[16]; + new SecureRandom().nextBytes(messageGuidCodecKey); + + messageGuidCodec = new MessageGuidCodec(SERVICE_IDENTIFIER.uuid(), + Device.PRIMARY_ID, + new VersionstampUUIDCipher(0, messageGuidCodecKey)); } @Test - void finitePublisherMultipleBatches() { + void finitePublisherMultipleBatches() throws InvalidProtocolBufferException { final MessageProtos.Envelope message1 = FoundationDbMessageStoreTest.generateRandomMessage(false); final MessageProtos.Envelope message2 = FoundationDbMessageStoreTest.generateRandomMessage(false); final MessageProtos.Envelope message3 = FoundationDbMessageStoreTest.generateRandomMessage(false); @@ -74,18 +96,19 @@ class FoundationDbMessagePublisherTest { }); final FoundationDbMessagePublisher finitePublisher = new FoundationDbMessagePublisher( - KeySelector.firstGreaterOrEqual(subspaceRange.begin), - KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}), + KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin), + KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.end), database, + messageGuidCodec, 2, // With 3 messages and batch size set to 2, we'll need to grab 2 batches. null, - (oldState, newState) -> stateTransitions.add(newState) + (_, newState) -> stateTransitions.add(newState) ); StepVerifier.create(finitePublisher.getMessages()) - .expectNext(new MessageStreamEntry.Envelope(message1)) - .expectNext(new MessageStreamEntry.Envelope(message2)) - .expectNext(new MessageStreamEntry.Envelope(message3)) + .expectNext(getExpectedMessageStreamEntry(keyValue1)) + .expectNext(getExpectedMessageStreamEntry(keyValue2)) + .expectNext(getExpectedMessageStreamEntry(keyValue3)) .verifyComplete(); assertEquals(List.of( @@ -102,7 +125,7 @@ class FoundationDbMessagePublisherTest { @Test @SuppressWarnings({"unchecked", "resource"}) - void infinitePublisher() { + void infinitePublisher() throws InvalidProtocolBufferException { final MessageProtos.Envelope message1 = FoundationDbMessageStoreTest.generateRandomMessage(false); final MessageProtos.Envelope message2 = FoundationDbMessageStoreTest.generateRandomMessage(false); final MessageProtos.Envelope message3 = FoundationDbMessageStoreTest.generateRandomMessage(false); @@ -138,17 +161,18 @@ class FoundationDbMessagePublisherTest { final CompletableFuture watchFuture1 = new CompletableFuture<>(); final CompletableFuture watchFuture2 = new CompletableFuture<>(); final CompletableFuture watchFuture3 = new CompletableFuture<>(); // this one will not be completed - when(transaction.watch(messagesAvailableWatchKey)) + when(transaction.watch(MESSAGES_AVAILABLE_WATCH_KEY)) .thenReturn(watchFuture1) .thenReturn(watchFuture2) .thenReturn(watchFuture3); final FoundationDbMessagePublisher infinitePublisher = new FoundationDbMessagePublisher( - KeySelector.firstGreaterOrEqual(subspaceRange.begin), + KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin), KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}), database, + messageGuidCodec, 2, - messagesAvailableWatchKey, + MESSAGES_AVAILABLE_WATCH_KEY, (oldState, newState) -> { stateTransitions.add(newState); if (newState == FoundationDbMessagePublisher.State.AWAITING_NEW_MESSAGES) { @@ -167,9 +191,9 @@ class FoundationDbMessagePublisherTest { ); StepVerifier.create(infinitePublisher.getMessages()) - .expectNext(new MessageStreamEntry.Envelope(message1)) - .expectNext(new MessageStreamEntry.Envelope(message2)) - .expectNext(new MessageStreamEntry.Envelope(message3)) + .expectNext(getExpectedMessageStreamEntry(keyValue1)) + .expectNext(getExpectedMessageStreamEntry(keyValue2)) + .expectNext(getExpectedMessageStreamEntry(keyValue3)) .verifyTimeout(Duration.ofSeconds(1)); assertEquals(List.of( @@ -192,7 +216,7 @@ class FoundationDbMessagePublisherTest { } @Test - void messageAvailableWatchSignalBuffered() { + void messageAvailableWatchSignalBuffered() throws InvalidProtocolBufferException { final MessageProtos.Envelope message1 = FoundationDbMessageStoreTest.generateRandomMessage(false); final MessageProtos.Envelope message2 = FoundationDbMessageStoreTest.generateRandomMessage(false); @@ -220,16 +244,17 @@ class FoundationDbMessagePublisherTest { final CompletableFuture watchFuture1 = new CompletableFuture<>(); final CompletableFuture watchFuture2 = new CompletableFuture<>(); // this one will not be completed - when(transaction.watch(messagesAvailableWatchKey)) + when(transaction.watch(MESSAGES_AVAILABLE_WATCH_KEY)) .thenReturn(watchFuture1) .thenReturn(watchFuture2); final FoundationDbMessagePublisher infinitePublisher = new FoundationDbMessagePublisher( - KeySelector.firstGreaterOrEqual(subspaceRange.begin), + KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin), KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}), database, + messageGuidCodec, 2, - messagesAvailableWatchKey, + MESSAGES_AVAILABLE_WATCH_KEY, (oldState, newState) -> { stateTransitions.add(newState); // Simulate an edge case where the messages available watch could trigger right after queue empty, but before @@ -246,8 +271,8 @@ class FoundationDbMessagePublisherTest { ); StepVerifier.create(infinitePublisher.getMessages()) - .expectNext(new MessageStreamEntry.Envelope(message1)) - .expectNext(new MessageStreamEntry.Envelope(message2)) + .expectNext(getExpectedMessageStreamEntry(keyValue1)) + .expectNext(getExpectedMessageStreamEntry(keyValue2)) .verifyTimeout(Duration.ofSeconds(1)); assertEquals(List.of( @@ -268,25 +293,24 @@ class FoundationDbMessagePublisherTest { @Test @SuppressWarnings({"unchecked", "resource"}) - void watchCanceledOnSubscriptionCancel() { + void watchCanceledOnSubscriptionCancel() throws InvalidProtocolBufferException { final FoundationDbMessagePublisher infinitePublisher = FoundationDbMessagePublisher.createInfinitePublisher( - KeySelector.firstGreaterOrEqual(subspaceRange.begin), - KeySelector.firstGreaterThan(subspaceRange.end), + KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin), + KeySelector.firstGreaterThan(SUBSPACE_RANGE.end), database, + messageGuidCodec, 100, - messagesAvailableWatchKey); + MESSAGES_AVAILABLE_WATCH_KEY); final MessageProtos.Envelope message = FoundationDbMessageStoreTest.generateRandomMessage(false); final Transaction transaction = mock(Transaction.class); - final KeyValue keyValue = mock(KeyValue.class); - when(keyValue.getKey()).thenReturn(new byte[]{(byte) 5}); - when(keyValue.getValue()).thenReturn(message.toByteArray()); + final KeyValue keyValue = mockKeyValue((byte) 5, message); final AsyncIterable asyncIterable = mock(AsyncIterable.class); when(asyncIterable.asList()).thenReturn(CompletableFuture.completedFuture(List.of(keyValue))); when(transaction.getRange(any(KeySelector.class), any(KeySelector.class), anyInt(), anyBoolean(), any( StreamingMode.class))) .thenReturn(asyncIterable); final CompletableFuture watchFuture = mock(CompletableFuture.class); - when(transaction.watch(messagesAvailableWatchKey)).thenReturn(watchFuture); + when(transaction.watch(MESSAGES_AVAILABLE_WATCH_KEY)).thenReturn(watchFuture); when(database.runAsync(any(Function.class))).thenAnswer( (Answer>>) invocationOnMock -> { @@ -295,16 +319,29 @@ class FoundationDbMessagePublisherTest { return f.apply(transaction); }); StepVerifier.create(infinitePublisher.getMessages()) - .expectNext(new MessageStreamEntry.Envelope(message)) + .expectNext(getExpectedMessageStreamEntry(keyValue)) .thenCancel() .verify(Duration.ofMillis(100)); verify(watchFuture).cancel(true); } - private KeyValue mockKeyValue(final byte key, final MessageProtos.Envelope message) { + private MessageStreamEntry.Envelope getExpectedMessageStreamEntry(final KeyValue keyValue) + throws InvalidProtocolBufferException { + return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue()) + .toBuilder() + .setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(FoundationDbMessageStore.getVersionstamp(keyValue.getKey())))) + .build()); + } + + private static KeyValue mockKeyValue(final byte key, final MessageProtos.Envelope message) { + final ByteBuffer versionstampBuffer = ByteBuffer.allocate(Versionstamp.LENGTH); + versionstampBuffer.put(11, key); + final KeyValue keyValue = mock(KeyValue.class); - when(keyValue.getKey()).thenReturn(new byte[]{key}); + when(keyValue.getKey()) + .thenReturn(FoundationDbMessageStore.getDeviceQueueSubspace(SERVICE_IDENTIFIER, Device.PRIMARY_ID) + .pack(Tuple.from(Versionstamp.fromBytes(versionstampBuffer.array())))); when(keyValue.getValue()).thenReturn(message.toByteArray()); return keyValue; } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index d2396b05f..e324707a3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.storage.foundationdb; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -13,11 +14,14 @@ import com.apple.foundationdb.tuple.Tuple; import com.apple.foundationdb.tuple.Versionstamp; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; +import io.dropwizard.util.DataSize; import java.io.UncheckedIOException; +import java.security.SecureRandom; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.time.ZoneId; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -30,11 +34,11 @@ import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; -import io.dropwizard.util.DataSize; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -51,6 +55,7 @@ import org.whispersystems.textsecuregcm.storage.MessageStream; import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.TestRandomUtil; +import org.whispersystems.textsecuregcm.util.UUIDUtil; import reactor.adapter.JdkFlowAdapter; import reactor.test.StepVerifier; @@ -60,14 +65,21 @@ class FoundationDbMessageStoreTest { @RegisterExtension static FoundationDbClusterExtension FOUNDATION_DB_EXTENSION = new FoundationDbClusterExtension(2); + private VersionstampUUIDCipher versionstampUUIDCipher; private FoundationDbMessageStore foundationDbMessageStore; private static final Clock CLOCK = Clock.fixed(Instant.ofEpochSecond(500), ZoneId.of("UTC")); @BeforeEach void setup() { + final byte[] versionstampCipherKey = new byte[16]; + new SecureRandom().nextBytes(versionstampCipherKey); + + versionstampUUIDCipher = new VersionstampUUIDCipher(0, versionstampCipherKey); + foundationDbMessageStore = new FoundationDbMessageStore( FOUNDATION_DB_EXTENSION.getDatabases(), + versionstampUUIDCipher, Executors.newVirtualThreadPerTaskExecutor(), CLOCK); } @@ -382,18 +394,34 @@ class FoundationDbMessageStoreTest { void getMessages(final int numMessages, final int batchSize) { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); writePresenceKey(aci, Device.PRIMARY_ID, 1, 5L); - for (int i = 0; i < numMessages; i++) { - final MessageProtos.Envelope message = generateRandomMessage(false); - assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message)).join()); - } + + final List expectedVersionstamps = IntStream.range(0, numMessages) + .mapToObj(_ -> foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join() + .get(Device.PRIMARY_ID) + .versionstamp() + .orElseThrow()) + .toList(); final Device device = new Device(); device.setId(Device.PRIMARY_ID); final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, batchSize); + final List retrievedEntries = new ArrayList<>(); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) + .recordWith(() -> retrievedEntries) .expectNextCount(numMessages) .expectNext(new MessageStreamEntry.QueueEmpty()) .verifyTimeout(Duration.ofSeconds(1)); + + final MessageGuidCodec messageGuidCodec = + new MessageGuidCodec(aci.uuid(), Device.PRIMARY_ID, versionstampUUIDCipher); + + for (int i = 0; i < expectedVersionstamps.size(); i++) { + final MessageStreamEntry.Envelope envelopeEntry = + assertInstanceOf(MessageStreamEntry.Envelope.class, retrievedEntries.get(i)); + + assertEquals(expectedVersionstamps.get(i), + messageGuidCodec.decodeMessageGuid(UUIDUtil.fromByteString(envelopeEntry.message().getServerGuidBinary()))); + } } static Stream getMessages() { @@ -410,20 +438,38 @@ class FoundationDbMessageStoreTest { @Test void getMessagesPublishMoreAfterQueueEmpty() { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); + final MessageGuidCodec messageGuidCodec = + new MessageGuidCodec(aci.uuid(), Device.PRIMARY_ID, versionstampUUIDCipher); + writePresenceKey(aci, Device.PRIMARY_ID, 1, 5L); final MessageProtos.Envelope message1 = generateRandomMessage(false); - assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message1)).join()); + final Versionstamp versionstamp1 = foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message1)).join() + .get(Device.PRIMARY_ID) + .versionstamp() + .orElseThrow(); + final MessageProtos.Envelope message2 = generateRandomMessage(false); - assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message2)).join()); + final Versionstamp versionstamp2 = foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message2)).join() + .get(Device.PRIMARY_ID) + .versionstamp() + .orElseThrow(); final CountDownLatch latch = new CountDownLatch(1); final MessageProtos.Envelope message3 = generateRandomMessage(false); + final AtomicReference versionstamp3 = new AtomicReference<>(); Thread.ofVirtual().start(() -> { try { // Wait until queue is empty assertTrue(latch.await(1000, TimeUnit.MILLISECONDS)); // Then publish more messages - assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message3)).join()); + synchronized (versionstamp3) { + versionstamp3.set(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message3)).join() + .get(Device.PRIMARY_ID) + .versionstamp() + .orElseThrow()); + + versionstamp3.notifyAll(); + } } catch (final InterruptedException e) { fail(e); } @@ -433,11 +479,34 @@ class FoundationDbMessageStoreTest { device.setId(Device.PRIMARY_ID); final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device); StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())) - .expectNext(new MessageStreamEntry.Envelope(message1)) - .expectNext(new MessageStreamEntry.Envelope(message2)) + .expectNext(new MessageStreamEntry.Envelope(message1 + .toBuilder() + .setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(versionstamp1))) + .build())) + .expectNext(new MessageStreamEntry.Envelope(message2 + .toBuilder() + .setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(versionstamp2))) + .build())) .expectNext(new MessageStreamEntry.QueueEmpty()) - .then(latch::countDown) - .expectNext(new MessageStreamEntry.Envelope(message3)) + .then(() -> { + // Trigger insertion of another message + latch.countDown(); + + // …but then wait for its versionstamp so we can verify that we have the right payload + synchronized (versionstamp3) { + while (versionstamp3.get() == null) { + try { + versionstamp3.wait(); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + } + } + }) + .expectNextMatches(entry -> entry.equals(new MessageStreamEntry.Envelope(message3 + .toBuilder() + .setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(versionstamp3.get()))) + .build()))) .verifyTimeout(Duration.ofSeconds(3)); }