Encode message versionstamps as GUIDs

This commit is contained in:
Jon Chambers
2026-04-01 11:20:37 -04:00
committed by Jon Chambers
parent d2cbdd4609
commit fb455bf1db
6 changed files with 246 additions and 65 deletions

View File

@@ -10,6 +10,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
@@ -25,6 +26,7 @@ import java.util.function.BiConsumer;
class FoundationDbMessagePublisher {
private final Database database;
private final MessageGuidCodec messageGuidCodec;
/// The maximum number of messages we will fetch per range query operation to avoid excessive memory consumption
private final int maxMessagesPerScan;
/// The end key at which we stop reading messages. For finite publisher, this is just past the end-of-queue key at the
@@ -88,15 +90,18 @@ class FoundationDbMessagePublisher {
private static final Logger LOGGER = LoggerFactory.getLogger(FoundationDbMessagePublisher.class);
FoundationDbMessagePublisher(
final KeySelector beginKeyInclusive,
final KeySelector endKeyExclusive,
final Database database,
final int maxMessagesPerScan,
@Nullable final byte[] messagesAvailableWatchKey,
@Nullable final BiConsumer<State, State> stateChangeListener) {
final KeySelector beginKeyInclusive,
final KeySelector endKeyExclusive,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan,
@Nullable final byte[] messagesAvailableWatchKey,
@Nullable final BiConsumer<State, State> stateChangeListener) {
this.beginKeyCursor = beginKeyInclusive;
this.endKeyExclusive = endKeyExclusive;
this.database = database;
this.messageGuidCodec = messageGuidCodec;
this.maxMessagesPerScan = maxMessagesPerScan;
this.messagesAvailableWatchKey = messagesAvailableWatchKey;
this.terminateOnQueueEmpty = messagesAvailableWatchKey == null;
@@ -116,8 +121,16 @@ class FoundationDbMessagePublisher {
final KeySelector beginKeyInclusive,
final KeySelector endKeyExclusive,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan) {
return new FoundationDbMessagePublisher(beginKeyInclusive, endKeyExclusive, database, maxMessagesPerScan, null, null);
return new FoundationDbMessagePublisher(beginKeyInclusive,
endKeyExclusive,
database,
messageGuidCodec,
maxMessagesPerScan,
null,
null);
}
/// Creates a [FoundationDbMessagePublisher] that publishes a non-terminating stream of messages from a device queue.
@@ -127,9 +140,17 @@ class FoundationDbMessagePublisher {
final KeySelector beginKeyInclusive,
final KeySelector endKeyExclusive,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan,
final byte[] messagesAvailableWatchKey) {
return new FoundationDbMessagePublisher(beginKeyInclusive, endKeyExclusive, database, maxMessagesPerScan, messagesAvailableWatchKey, null);
return new FoundationDbMessagePublisher(beginKeyInclusive,
endKeyExclusive,
database,
messageGuidCodec,
maxMessagesPerScan,
messagesAvailableWatchKey,
null);
}
private synchronized void setState(final State newState, final Event event) {
@@ -222,7 +243,7 @@ class FoundationDbMessagePublisher {
///
/// @return a future of a list of [MessageStreamEntry] with a max size of [#maxMessagesPerScan]
private CompletableFuture<List<MessageStreamEntry.Envelope>> getMessagesBatch() {
return database.runAsync(transaction -> getItemsInRange(transaction, beginKeyCursor, endKeyExclusive, maxMessagesPerScan)
return database.runAsync(transaction -> getItemsInRange(transaction, messageGuidCodec, beginKeyCursor, endKeyExclusive, maxMessagesPerScan)
.thenApply(lastKeyReadAndItems -> {
// Set our beginning key to just past the last key read so that we're ready for our next fetch
lastKeyReadAndItems.first().ifPresent(lastKeyRead -> beginKeyCursor = KeySelector.firstGreaterThan(lastKeyRead));
@@ -248,6 +269,7 @@ class FoundationDbMessagePublisher {
/// @return the last key read (if there were non-zero number of messages read) and the list of messages read
private static CompletableFuture<Pair<Optional<byte[]>, List<MessageStreamEntry.Envelope>>> getItemsInRange(
final Transaction transaction,
final MessageGuidCodec messageGuidCodec,
final KeySelector beginInclusive,
final KeySelector endExclusive,
final int maxMessagesPerScan) {
@@ -259,7 +281,10 @@ class FoundationDbMessagePublisher {
final List<MessageStreamEntry.Envelope> messages = keyValues.stream()
.map(keyValue -> {
try {
return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue()));
return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue())
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(FoundationDbMessageStore.getVersionstamp(keyValue.getKey()))))
.build());
} catch (final InvalidProtocolBufferException e) {
throw new UncheckedIOException(e);
}

View File

@@ -40,6 +40,7 @@ import org.whispersystems.textsecuregcm.util.Conversions;
public class FoundationDbMessageStore {
private final Database[] databases;
private final VersionstampUUIDCipher versionstampUUIDCipher;
private final Executor executor;
private final Clock clock;
@@ -60,8 +61,13 @@ public class FoundationDbMessageStore {
public record InsertResult(Optional<Versionstamp> versionstamp, boolean present) {
}
public FoundationDbMessageStore(final Database[] databases, final Executor executor, final Clock clock) {
public FoundationDbMessageStore(final Database[] databases,
final VersionstampUUIDCipher versionstampUUIDCipher,
final Executor executor,
final Clock clock) {
this.databases = databases;
this.versionstampUUIDCipher = versionstampUUIDCipher;
this.executor = executor;
this.clock = clock;
}
@@ -263,9 +269,14 @@ public class FoundationDbMessageStore {
return new FoundationDbMessageStream(getDeviceQueueSubspace(aci, destinationDevice.getId()),
getMessagesAvailableWatchKey(aci),
getShardForAci(aci),
new MessageGuidCodec(aci.uuid(), destinationDevice.getId(), versionstampUUIDCipher),
maxMessagesPerScan);
}
static Versionstamp getVersionstamp(final byte[] messageKey) {
return Tuple.fromBytes(messageKey).getVersionstamp(4);
}
@VisibleForTesting
Database getShardForAci(final AciServiceIdentifier aci) {
return databases[hashAciToShardNumber(aci)];
@@ -278,20 +289,20 @@ public class FoundationDbMessageStore {
}
@VisibleForTesting
Subspace getDeviceQueueSubspace(final AciServiceIdentifier aci, final byte deviceId) {
static Subspace getDeviceQueueSubspace(final AciServiceIdentifier aci, final byte deviceId) {
return getDeviceSubspace(aci, deviceId).get("Q");
}
private Subspace getDeviceSubspace(final AciServiceIdentifier aci, final byte deviceId) {
private static Subspace getDeviceSubspace(final AciServiceIdentifier aci, final byte deviceId) {
return getAccountSubspace(aci).get(deviceId);
}
private Subspace getAccountSubspace(final AciServiceIdentifier aci) {
private static Subspace getAccountSubspace(final AciServiceIdentifier aci) {
return MESSAGES_SUBSPACE.get(aci.uuid());
}
@VisibleForTesting
byte[] getMessagesAvailableWatchKey(final AciServiceIdentifier aci) {
static byte[] getMessagesAvailableWatchKey(final AciServiceIdentifier aci) {
return getAccountSubspace(aci).pack("l");
}

View File

@@ -21,17 +21,23 @@ public class FoundationDbMessageStream implements MessageStream {
private final Subspace deviceQueueSubspace;
private final byte[] messagesAvailableWatchKey;
private final Database database;
private final MessageGuidCodec messageGuidCodec;
/// The maximum number of messages we will fetch per range query operation to avoid excessive memory consumption
private final int maxMessagesPerScan;
private final Flow.Publisher<MessageStreamEntry> messageStreamPublisher;
static final int DEFAULT_MAX_MESSAGES_PER_SCAN = 1024;
FoundationDbMessageStream(final Subspace deviceQueueSubspace, final byte[] messagesAvailableWatchKey,
final Database database, final int maxMessagesPerScan) {
FoundationDbMessageStream(final Subspace deviceQueueSubspace,
final byte[] messagesAvailableWatchKey,
final Database database,
final MessageGuidCodec messageGuidCodec,
final int maxMessagesPerScan) {
this.deviceQueueSubspace = deviceQueueSubspace;
this.messagesAvailableWatchKey = messagesAvailableWatchKey;
this.database = database;
this.messageGuidCodec = messageGuidCodec;
this.maxMessagesPerScan = maxMessagesPerScan;
this.messageStreamPublisher = JdkFlowAdapter.publisherToFlowPublisher(createMessagePublisher());
}
@@ -59,13 +65,13 @@ public class FoundationDbMessageStream implements MessageStream {
final Flux<MessageStreamEntry.Envelope> finitePublisher = maybeEndOfQueueKeyExclusive
.map(endOfQueueKeyExclusive -> FoundationDbMessagePublisher.createFinitePublisher(
KeySelector.firstGreaterOrEqual(deviceQueueSubspace.range().begin),
endOfQueueKeyExclusive, database, maxMessagesPerScan).getMessages())
endOfQueueKeyExclusive, database, messageGuidCodec, maxMessagesPerScan).getMessages())
.orElseGet(Flux::empty);
final KeySelector infinitePublisherBeginKey = maybeEndOfQueueKeyExclusive.orElseGet(
() -> KeySelector.firstGreaterOrEqual(deviceQueueSubspace.range().begin));
final Flux<MessageStreamEntry.Envelope> infinitePublisher = FoundationDbMessagePublisher.createInfinitePublisher(
infinitePublisherBeginKey, KeySelector.firstGreaterThan(deviceQueueSubspace.range().end),
database, maxMessagesPerScan, messagesAvailableWatchKey).getMessages();
database, messageGuidCodec, maxMessagesPerScan, messagesAvailableWatchKey).getMessages();
return Flux.concat(
finitePublisher,
Mono.just(new MessageStreamEntry.QueueEmpty()),

View File

@@ -0,0 +1,33 @@
/*
* Copyright 2026 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage.foundationdb;
import com.apple.foundationdb.tuple.Versionstamp;
import java.util.UUID;
class MessageGuidCodec {
private final UUID accountIdentifier;
private final byte deviceId;
private final VersionstampUUIDCipher versionstampUUIDCipher;
MessageGuidCodec(final UUID accountIdentifier,
final byte deviceId,
final VersionstampUUIDCipher versionstampUUIDCipher) {
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.versionstampUUIDCipher = versionstampUUIDCipher;
}
public UUID encodeMessageGuid(final Versionstamp versionstamp) {
return versionstampUUIDCipher.encryptVersionstamp(versionstamp, accountIdentifier, deviceId);
}
public Versionstamp decodeMessageGuid(final UUID messageGuid) {
return versionstampUUIDCipher.decryptVersionstamp(messageGuid, accountIdentifier, deviceId);
}
}

View File

@@ -16,36 +16,58 @@ import com.apple.foundationdb.Range;
import com.apple.foundationdb.StreamingMode;
import com.apple.foundationdb.Transaction;
import com.apple.foundationdb.async.AsyncIterable;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import com.apple.foundationdb.tuple.Tuple;
import com.apple.foundationdb.tuple.Versionstamp;
import com.google.protobuf.InvalidProtocolBufferException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.test.StepVerifier;
/// NOTE: most of the happy-path test cases are already covered in {@link FoundationDbMessageStoreTest}, this test
/// mostly exercises edge-cases and error handling that are hard to test without mocks
class FoundationDbMessagePublisherTest {
private final Range subspaceRange = new Range(new byte[]{(byte) 0}, new byte[]{(byte) 100});
private final byte[] messagesAvailableWatchKey = new byte[]{(byte) 42};
private Database database;
private MessageGuidCodec messageGuidCodec;
private List<FoundationDbMessagePublisher.State> stateTransitions;
private static final AciServiceIdentifier SERVICE_IDENTIFIER = new AciServiceIdentifier(UUID.randomUUID());
private static final Range SUBSPACE_RANGE =
FoundationDbMessageStore.getDeviceQueueSubspace(SERVICE_IDENTIFIER, Device.PRIMARY_ID).range();
private static final byte[] MESSAGES_AVAILABLE_WATCH_KEY =
FoundationDbMessageStore.getMessagesAvailableWatchKey(SERVICE_IDENTIFIER);
@BeforeEach
void setUp() {
database = mock(Database.class);
stateTransitions = new ArrayList<>();
final byte[] messageGuidCodecKey = new byte[16];
new SecureRandom().nextBytes(messageGuidCodecKey);
messageGuidCodec = new MessageGuidCodec(SERVICE_IDENTIFIER.uuid(),
Device.PRIMARY_ID,
new VersionstampUUIDCipher(0, messageGuidCodecKey));
}
@Test
void finitePublisherMultipleBatches() {
void finitePublisherMultipleBatches() throws InvalidProtocolBufferException {
final MessageProtos.Envelope message1 = FoundationDbMessageStoreTest.generateRandomMessage(false);
final MessageProtos.Envelope message2 = FoundationDbMessageStoreTest.generateRandomMessage(false);
final MessageProtos.Envelope message3 = FoundationDbMessageStoreTest.generateRandomMessage(false);
@@ -74,18 +96,19 @@ class FoundationDbMessagePublisherTest {
});
final FoundationDbMessagePublisher finitePublisher = new FoundationDbMessagePublisher(
KeySelector.firstGreaterOrEqual(subspaceRange.begin),
KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}),
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.end),
database,
messageGuidCodec,
2, // With 3 messages and batch size set to 2, we'll need to grab 2 batches.
null,
(oldState, newState) -> stateTransitions.add(newState)
(_, newState) -> stateTransitions.add(newState)
);
StepVerifier.create(finitePublisher.getMessages())
.expectNext(new MessageStreamEntry.Envelope(message1))
.expectNext(new MessageStreamEntry.Envelope(message2))
.expectNext(new MessageStreamEntry.Envelope(message3))
.expectNext(getExpectedMessageStreamEntry(keyValue1))
.expectNext(getExpectedMessageStreamEntry(keyValue2))
.expectNext(getExpectedMessageStreamEntry(keyValue3))
.verifyComplete();
assertEquals(List.of(
@@ -102,7 +125,7 @@ class FoundationDbMessagePublisherTest {
@Test
@SuppressWarnings({"unchecked", "resource"})
void infinitePublisher() {
void infinitePublisher() throws InvalidProtocolBufferException {
final MessageProtos.Envelope message1 = FoundationDbMessageStoreTest.generateRandomMessage(false);
final MessageProtos.Envelope message2 = FoundationDbMessageStoreTest.generateRandomMessage(false);
final MessageProtos.Envelope message3 = FoundationDbMessageStoreTest.generateRandomMessage(false);
@@ -138,17 +161,18 @@ class FoundationDbMessagePublisherTest {
final CompletableFuture<Void> watchFuture1 = new CompletableFuture<>();
final CompletableFuture<Void> watchFuture2 = new CompletableFuture<>();
final CompletableFuture<Void> watchFuture3 = new CompletableFuture<>(); // this one will not be completed
when(transaction.watch(messagesAvailableWatchKey))
when(transaction.watch(MESSAGES_AVAILABLE_WATCH_KEY))
.thenReturn(watchFuture1)
.thenReturn(watchFuture2)
.thenReturn(watchFuture3);
final FoundationDbMessagePublisher infinitePublisher = new FoundationDbMessagePublisher(
KeySelector.firstGreaterOrEqual(subspaceRange.begin),
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}),
database,
messageGuidCodec,
2,
messagesAvailableWatchKey,
MESSAGES_AVAILABLE_WATCH_KEY,
(oldState, newState) -> {
stateTransitions.add(newState);
if (newState == FoundationDbMessagePublisher.State.AWAITING_NEW_MESSAGES) {
@@ -167,9 +191,9 @@ class FoundationDbMessagePublisherTest {
);
StepVerifier.create(infinitePublisher.getMessages())
.expectNext(new MessageStreamEntry.Envelope(message1))
.expectNext(new MessageStreamEntry.Envelope(message2))
.expectNext(new MessageStreamEntry.Envelope(message3))
.expectNext(getExpectedMessageStreamEntry(keyValue1))
.expectNext(getExpectedMessageStreamEntry(keyValue2))
.expectNext(getExpectedMessageStreamEntry(keyValue3))
.verifyTimeout(Duration.ofSeconds(1));
assertEquals(List.of(
@@ -192,7 +216,7 @@ class FoundationDbMessagePublisherTest {
}
@Test
void messageAvailableWatchSignalBuffered() {
void messageAvailableWatchSignalBuffered() throws InvalidProtocolBufferException {
final MessageProtos.Envelope message1 = FoundationDbMessageStoreTest.generateRandomMessage(false);
final MessageProtos.Envelope message2 = FoundationDbMessageStoreTest.generateRandomMessage(false);
@@ -220,16 +244,17 @@ class FoundationDbMessagePublisherTest {
final CompletableFuture<Void> watchFuture1 = new CompletableFuture<>();
final CompletableFuture<Void> watchFuture2 = new CompletableFuture<>(); // this one will not be completed
when(transaction.watch(messagesAvailableWatchKey))
when(transaction.watch(MESSAGES_AVAILABLE_WATCH_KEY))
.thenReturn(watchFuture1)
.thenReturn(watchFuture2);
final FoundationDbMessagePublisher infinitePublisher = new FoundationDbMessagePublisher(
KeySelector.firstGreaterOrEqual(subspaceRange.begin),
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterOrEqual(new byte[]{(byte) 10}),
database,
messageGuidCodec,
2,
messagesAvailableWatchKey,
MESSAGES_AVAILABLE_WATCH_KEY,
(oldState, newState) -> {
stateTransitions.add(newState);
// Simulate an edge case where the messages available watch could trigger right after queue empty, but before
@@ -246,8 +271,8 @@ class FoundationDbMessagePublisherTest {
);
StepVerifier.create(infinitePublisher.getMessages())
.expectNext(new MessageStreamEntry.Envelope(message1))
.expectNext(new MessageStreamEntry.Envelope(message2))
.expectNext(getExpectedMessageStreamEntry(keyValue1))
.expectNext(getExpectedMessageStreamEntry(keyValue2))
.verifyTimeout(Duration.ofSeconds(1));
assertEquals(List.of(
@@ -268,25 +293,24 @@ class FoundationDbMessagePublisherTest {
@Test
@SuppressWarnings({"unchecked", "resource"})
void watchCanceledOnSubscriptionCancel() {
void watchCanceledOnSubscriptionCancel() throws InvalidProtocolBufferException {
final FoundationDbMessagePublisher infinitePublisher = FoundationDbMessagePublisher.createInfinitePublisher(
KeySelector.firstGreaterOrEqual(subspaceRange.begin),
KeySelector.firstGreaterThan(subspaceRange.end),
KeySelector.firstGreaterOrEqual(SUBSPACE_RANGE.begin),
KeySelector.firstGreaterThan(SUBSPACE_RANGE.end),
database,
messageGuidCodec,
100,
messagesAvailableWatchKey);
MESSAGES_AVAILABLE_WATCH_KEY);
final MessageProtos.Envelope message = FoundationDbMessageStoreTest.generateRandomMessage(false);
final Transaction transaction = mock(Transaction.class);
final KeyValue keyValue = mock(KeyValue.class);
when(keyValue.getKey()).thenReturn(new byte[]{(byte) 5});
when(keyValue.getValue()).thenReturn(message.toByteArray());
final KeyValue keyValue = mockKeyValue((byte) 5, message);
final AsyncIterable<KeyValue> asyncIterable = mock(AsyncIterable.class);
when(asyncIterable.asList()).thenReturn(CompletableFuture.completedFuture(List.of(keyValue)));
when(transaction.getRange(any(KeySelector.class), any(KeySelector.class), anyInt(), anyBoolean(), any(
StreamingMode.class)))
.thenReturn(asyncIterable);
final CompletableFuture<Void> watchFuture = mock(CompletableFuture.class);
when(transaction.watch(messagesAvailableWatchKey)).thenReturn(watchFuture);
when(transaction.watch(MESSAGES_AVAILABLE_WATCH_KEY)).thenReturn(watchFuture);
when(database.runAsync(any(Function.class))).thenAnswer(
(Answer<CompletableFuture<List<? extends MessageStreamEntry>>>) invocationOnMock -> {
@@ -295,16 +319,29 @@ class FoundationDbMessagePublisherTest {
return f.apply(transaction);
});
StepVerifier.create(infinitePublisher.getMessages())
.expectNext(new MessageStreamEntry.Envelope(message))
.expectNext(getExpectedMessageStreamEntry(keyValue))
.thenCancel()
.verify(Duration.ofMillis(100));
verify(watchFuture).cancel(true);
}
private KeyValue mockKeyValue(final byte key, final MessageProtos.Envelope message) {
private MessageStreamEntry.Envelope getExpectedMessageStreamEntry(final KeyValue keyValue)
throws InvalidProtocolBufferException {
return new MessageStreamEntry.Envelope(MessageProtos.Envelope.parseFrom(keyValue.getValue())
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(FoundationDbMessageStore.getVersionstamp(keyValue.getKey()))))
.build());
}
private static KeyValue mockKeyValue(final byte key, final MessageProtos.Envelope message) {
final ByteBuffer versionstampBuffer = ByteBuffer.allocate(Versionstamp.LENGTH);
versionstampBuffer.put(11, key);
final KeyValue keyValue = mock(KeyValue.class);
when(keyValue.getKey()).thenReturn(new byte[]{key});
when(keyValue.getKey())
.thenReturn(FoundationDbMessageStore.getDeviceQueueSubspace(SERVICE_IDENTIFIER, Device.PRIMARY_ID)
.pack(Tuple.from(Versionstamp.fromBytes(versionstampBuffer.array()))));
when(keyValue.getValue()).thenReturn(message.toByteArray());
return keyValue;
}

View File

@@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.storage.foundationdb;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -13,11 +14,14 @@ import com.apple.foundationdb.tuple.Tuple;
import com.apple.foundationdb.tuple.Versionstamp;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.util.DataSize;
import java.io.UncheckedIOException;
import java.security.SecureRandom;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -30,11 +34,11 @@ import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import io.dropwizard.util.DataSize;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
@@ -51,6 +55,7 @@ import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.adapter.JdkFlowAdapter;
import reactor.test.StepVerifier;
@@ -60,14 +65,21 @@ class FoundationDbMessageStoreTest {
@RegisterExtension
static FoundationDbClusterExtension FOUNDATION_DB_EXTENSION = new FoundationDbClusterExtension(2);
private VersionstampUUIDCipher versionstampUUIDCipher;
private FoundationDbMessageStore foundationDbMessageStore;
private static final Clock CLOCK = Clock.fixed(Instant.ofEpochSecond(500), ZoneId.of("UTC"));
@BeforeEach
void setup() {
final byte[] versionstampCipherKey = new byte[16];
new SecureRandom().nextBytes(versionstampCipherKey);
versionstampUUIDCipher = new VersionstampUUIDCipher(0, versionstampCipherKey);
foundationDbMessageStore = new FoundationDbMessageStore(
FOUNDATION_DB_EXTENSION.getDatabases(),
versionstampUUIDCipher,
Executors.newVirtualThreadPerTaskExecutor(),
CLOCK);
}
@@ -382,18 +394,34 @@ class FoundationDbMessageStoreTest {
void getMessages(final int numMessages, final int batchSize) {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
writePresenceKey(aci, Device.PRIMARY_ID, 1, 5L);
for (int i = 0; i < numMessages; i++) {
final MessageProtos.Envelope message = generateRandomMessage(false);
assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message)).join());
}
final List<Versionstamp> expectedVersionstamps = IntStream.range(0, numMessages)
.mapToObj(_ -> foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join()
.get(Device.PRIMARY_ID)
.versionstamp()
.orElseThrow())
.toList();
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device, batchSize);
final List<MessageStreamEntry> retrievedEntries = new ArrayList<>();
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()))
.recordWith(() -> retrievedEntries)
.expectNextCount(numMessages)
.expectNext(new MessageStreamEntry.QueueEmpty())
.verifyTimeout(Duration.ofSeconds(1));
final MessageGuidCodec messageGuidCodec =
new MessageGuidCodec(aci.uuid(), Device.PRIMARY_ID, versionstampUUIDCipher);
for (int i = 0; i < expectedVersionstamps.size(); i++) {
final MessageStreamEntry.Envelope envelopeEntry =
assertInstanceOf(MessageStreamEntry.Envelope.class, retrievedEntries.get(i));
assertEquals(expectedVersionstamps.get(i),
messageGuidCodec.decodeMessageGuid(UUIDUtil.fromByteString(envelopeEntry.message().getServerGuidBinary())));
}
}
static Stream<Arguments> getMessages() {
@@ -410,20 +438,38 @@ class FoundationDbMessageStoreTest {
@Test
void getMessagesPublishMoreAfterQueueEmpty() {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
final MessageGuidCodec messageGuidCodec =
new MessageGuidCodec(aci.uuid(), Device.PRIMARY_ID, versionstampUUIDCipher);
writePresenceKey(aci, Device.PRIMARY_ID, 1, 5L);
final MessageProtos.Envelope message1 = generateRandomMessage(false);
assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message1)).join());
final Versionstamp versionstamp1 = foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message1)).join()
.get(Device.PRIMARY_ID)
.versionstamp()
.orElseThrow();
final MessageProtos.Envelope message2 = generateRandomMessage(false);
assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message2)).join());
final Versionstamp versionstamp2 = foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message2)).join()
.get(Device.PRIMARY_ID)
.versionstamp()
.orElseThrow();
final CountDownLatch latch = new CountDownLatch(1);
final MessageProtos.Envelope message3 = generateRandomMessage(false);
final AtomicReference<Versionstamp> versionstamp3 = new AtomicReference<>();
Thread.ofVirtual().start(() -> {
try {
// Wait until queue is empty
assertTrue(latch.await(1000, TimeUnit.MILLISECONDS));
// Then publish more messages
assertNotNull(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message3)).join());
synchronized (versionstamp3) {
versionstamp3.set(foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, message3)).join()
.get(Device.PRIMARY_ID)
.versionstamp()
.orElseThrow());
versionstamp3.notifyAll();
}
} catch (final InterruptedException e) {
fail(e);
}
@@ -433,11 +479,34 @@ class FoundationDbMessageStoreTest {
device.setId(Device.PRIMARY_ID);
final MessageStream messageStream = foundationDbMessageStore.getMessages(aci, device);
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()))
.expectNext(new MessageStreamEntry.Envelope(message1))
.expectNext(new MessageStreamEntry.Envelope(message2))
.expectNext(new MessageStreamEntry.Envelope(message1
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(versionstamp1)))
.build()))
.expectNext(new MessageStreamEntry.Envelope(message2
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(versionstamp2)))
.build()))
.expectNext(new MessageStreamEntry.QueueEmpty())
.then(latch::countDown)
.expectNext(new MessageStreamEntry.Envelope(message3))
.then(() -> {
// Trigger insertion of another message
latch.countDown();
// …but then wait for its versionstamp so we can verify that we have the right payload
synchronized (versionstamp3) {
while (versionstamp3.get() == null) {
try {
versionstamp3.wait();
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}
}
}
})
.expectNextMatches(entry -> entry.equals(new MessageStreamEntry.Envelope(message3
.toBuilder()
.setServerGuidBinary(UUIDUtil.toByteString(messageGuidCodec.encodeMessageGuid(versionstamp3.get())))
.build())))
.verifyTimeout(Duration.ofSeconds(3));
}