diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dd6a9d4a9..2a568d8de 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,11 +13,16 @@ jobs: timeout-minutes: 20 services: - foundationdb: + foundationdb0: # Note: this should generally match the version of the FoundationDB SERVER deployed in production; it's okay if # it's a little behind the CLIENT version. image: foundationdb/foundationdb:7.3.62 - options: --name foundationdb + options: --name foundationdb0 + foundationdb1: + # Note: this should generally match the version of the FoundationDB SERVER deployed in production; it's okay if + # it's a little behind the CLIENT version. + image: foundationdb/foundationdb:7.3.62 + options: --name foundationdb1 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -49,12 +54,14 @@ jobs: # ca-certificates: required for AWS CRT client apt update && apt install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin ca-certificates - - name: Configure FoundationDB database - run: docker exec foundationdb /usr/bin/fdbcli --exec 'configure new single memory' + - name: Configure FoundationDB0 database + run: docker exec foundationdb0 /usr/bin/fdbcli --exec 'configure new single memory' + - name: Configure FoundationDB1 database + run: docker exec foundationdb1 /usr/bin/fdbcli --exec 'configure new single memory' - name: Download and install FoundationDB client run: | ./mvnw -e -B -Pexclude-spam-filter clean prepare-package -DskipTests=true cp service/target/jib-extra/usr/lib/libfdb_c.x86_64.so /usr/lib/libfdb_c.x86_64.so ldconfig - name: Build with Maven - run: ./mvnw -e -B clean verify -DfoundationDb.serviceContainerName=foundationdb + run: ./mvnw -e -B clean verify -DfoundationDb.serviceContainerNamePrefix=foundationdb diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java 
b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index 1d20e104c..8a2870b9a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -8,17 +8,21 @@ import com.apple.foundationdb.tuple.Tuple; import com.apple.foundationdb.tuple.Versionstamp; import com.google.common.annotations.VisibleForTesting; import com.google.common.hash.Hashing; +import io.dropwizard.util.DataSize; import java.time.Clock; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Function; +import java.util.stream.Collectors; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.util.Conversions; -import org.whispersystems.textsecuregcm.util.Pair; /// An implementation of a message store backed by FoundationDB. /// @@ -38,7 +42,21 @@ public class FoundationDbMessageStore { private final Clock clock; private static final Subspace MESSAGES_SUBSPACE = new Subspace(Tuple.from("M")); - private static final int MAX_SECONDS_SINCE_UPDATE_FOR_PRESENCE = 300; + private static final Duration PRESENCE_STALE_THRESHOLD = Duration.ofMinutes(5); + + /// The (approximate) transaction size beyond which we do not add more messages in a transaction. The estimated size + /// includes only message payloads (and not key reads/writes) which we assume will dominate the total + /// transaction size. 
Note that the FDB [docs](https://apple.github.io/foundationdb/known-limitations.html) currently + /// suggest a limit of 1MB to avoid performance issues, although the hard limit is 10MB + private static final long MAX_MESSAGE_CHUNK_SIZE = DataSize.megabytes(1).toBytes(); + + /// Result of inserting a message for a particular device + /// + /// @param versionstamp the versionstamp of the transaction in which this device's message was inserted, empty + /// otherwise + /// @param present whether the device is online + public record InsertResult(Optional versionstamp, boolean present) { + } public FoundationDbMessageStore(final Database[] databases, final Executor executor, final Clock clock) { this.databases = databases; @@ -46,75 +64,195 @@ public class FoundationDbMessageStore { this.clock = clock; } - /// Insert a message bundle for a set of devices belonging to a single recipient. A message may not be inserted if the - /// device is not present (as determined from its presence key) and the message is ephemeral. If all messages in the - /// bundle don't end up being inserted, we won't return a versionstamp since the transaction was read-only. + /// Convenience method for inserting a single recipient message bundle. See [#insert(Map)] for details. 
/// - /// @param aci destination account identifier - /// @param messagesByDeviceId a map of deviceId => message envelope - /// @return a future that completes with a [Versionstamp] of the committed transaction if at least one message was - /// inserted - public CompletableFuture> insert(final AciServiceIdentifier aci, + /// @param aciServiceIdentifier accountId of the recipient + /// @param messagesByDeviceId a map of message envelopes by deviceId to be inserted + /// @return a future that yields a map deviceId => the presence state and versionstamp of the transaction in which the + /// device's message was inserted (if any) + public CompletableFuture> insert(final AciServiceIdentifier aciServiceIdentifier, final Map messagesByDeviceId) { - // We use Database#runAsync and not Database#run here because the latter would commit the transaction synchronously - // and we would like to avoid any potential blocking in native code that could unexpectedly pin virtual threads. See https://forums.foundationdb.org/t/fdbdatabase-usage-from-java-api/593/2 - // for details. 
- return getShardForAci(aci).runAsync(transaction -> insert(aci, messagesByDeviceId, transaction) - .thenApply(hasMutations -> { - if (hasMutations) { - return transaction.getVersionstamp(); - } - return CompletableFuture.completedFuture((byte[]) null); - })) - .thenComposeAsync(Function.identity(), executor) - .thenApply(versionstampBytes -> Optional.ofNullable(versionstampBytes).map(Versionstamp::complete)); - } - private CompletableFuture insert(final AciServiceIdentifier aci, - final Map messagesByDeviceId, - final Transaction transaction) { - final List>> messageInsertFutures = messagesByDeviceId.entrySet() - .stream() - .map(e -> { - final byte deviceId = e.getKey(); - final MessageProtos.Envelope message = e.getValue(); - final byte[] presenceKey = getPresenceKey(aci, deviceId); - return transaction.get(presenceKey) - .thenApply(this::isClientPresent) - .thenApply(isPresent -> { - boolean hasMutations = false; - if (isPresent || !message.getEphemeral()) { - final Subspace deviceQueueSubspace = getDeviceQueueSubspace(aci, deviceId); - transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY, - deviceQueueSubspace.packWithVersionstamp(Tuple.from( - Versionstamp.incomplete())), message.toByteArray()); - hasMutations = true; - } - return new Pair<>(isPresent, hasMutations); - }); - }) - .toList(); - return CompletableFuture.allOf(messageInsertFutures.toArray(CompletableFuture[]::new)) - .thenApply(_ -> { - final boolean anyClientPresent = messageInsertFutures - .stream() - .anyMatch(future -> future.join().first()); - final boolean hasMutations = messageInsertFutures - .stream() - .anyMatch(future -> future.join().second()); - if (anyClientPresent && hasMutations) { - transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getMessagesAvailableWatchKey(aci), - Tuple.from(Versionstamp.incomplete()).packWithVersionstamp()); - } - return hasMutations; + return insert(Map.of(aciServiceIdentifier, messagesByDeviceId)) + .thenApply(resultsByServiceIdentifier -> { + 
assert resultsByServiceIdentifier.size() == 1; + + return resultsByServiceIdentifier.get(aciServiceIdentifier); }); } - private Database getShardForAci(final AciServiceIdentifier aci) { + /// Insert a multi-recipient message bundle. Destination ACIs are grouped by shard number. Each shard then starts a + /// potentially multi-transaction operation. Messages are inserted in chunks to avoid transaction size limits. + /// + /// @param messagesByServiceIdentifier a map of accountId to message envelopes by deviceId + /// @return a future that yields a map containing the presence states of devices and versionstamps corresponding to + /// committed transactions during this operation + /// + /// @implNote All messages belonging to the same recipient are always committed in the same transaction for + /// simplicity. A message may not be inserted if the device is not present (as determined from its presence key) and + /// the message is ephemeral. If no messages in a transaction end up being inserted, we won't commit it since the + /// transaction was read-only. As such, no corresponding versionstamp is generated. 
+  public CompletableFuture<Map<AciServiceIdentifier, Map<Byte, InsertResult>>> insert(
+      final Map<AciServiceIdentifier, Map<Byte, MessageProtos.Envelope>> messagesByServiceIdentifier) {
+
+    if (messagesByServiceIdentifier.entrySet()
+        .stream()
+        .anyMatch(entry -> entry.getValue().isEmpty())) {
+      throw new IllegalArgumentException("One or more message bundles is empty");
+    }
+
+    final Map<Integer, List<Map.Entry<AciServiceIdentifier, Map<Byte, MessageProtos.Envelope>>>> messagesByShardId =
+        messagesByServiceIdentifier.entrySet().stream()
+            .collect(Collectors.groupingBy(entry -> hashAciToShardNumber(entry.getKey())));
+
+    final List<CompletableFuture<Map<AciServiceIdentifier, Map<Byte, InsertResult>>>> chunkFutures =
+        new ArrayList<>();
+
+    messagesByShardId.forEach((shardId, messagesForShard) -> {
+      final Database shard = databases[shardId];
+
+      // Index of the first recipient in the chunk currently being assembled
+      int start = 0;
+      long estimatedTransactionSize = 0;
+
+      for (int current = 0; current < messagesForShard.size(); current++) {
+        final long bundleSize = messagesForShard.get(current).getValue().values()
+            .stream()
+            .mapToLong(MessageProtos.Envelope::getSerializedSize)
+            .sum();
+
+        // Close the pending chunk before adding a bundle that would push it over the limit. The `current > start`
+        // guard keeps us from ever emitting an empty chunk: a recipient whose bundle alone exceeds the limit is
+        // inserted in a transaction of its own instead of failing in `insertChunk`, since all messages for one
+        // recipient must be committed together.
+        if (current > start && estimatedTransactionSize + bundleSize > MAX_MESSAGE_CHUNK_SIZE) {
+          chunkFutures.add(insertChunk(shard, messagesForShard.subList(start, current)));
+
+          start = current;
+          estimatedTransactionSize = 0;
+        }
+
+        estimatedTransactionSize += bundleSize;
+      }
+
+      assert start < messagesForShard.size();
+      chunkFutures.add(insertChunk(shard, messagesForShard.subList(start, messagesForShard.size())));
+    });
+
+    return CompletableFuture.allOf(chunkFutures.toArray(CompletableFuture[]::new))
+        .thenApply(_ -> {
+          // Every recipient belongs to exactly one chunk, so merging the per-chunk maps never overwrites an entry
+          final Map<AciServiceIdentifier, Map<Byte, InsertResult>> resultsByServiceIdentifier = new HashMap<>();
+          chunkFutures.forEach(chunkFuture -> resultsByServiceIdentifier.putAll(chunkFuture.join()));
+
+          return resultsByServiceIdentifier;
+        });
+  }
+
+  /// Inserts the message bundles of one or more recipients in a single transaction against the given database.
+  ///
+  /// @param database the database (shard) in which to run the transaction
+  /// @param messagesByAccountIdentifier a non-empty list of (recipient, message envelopes by deviceId) entries
+  /// @return a future that yields, for each recipient and device, the device's presence state and the versionstamp of
+  /// the transaction (empty for devices whose message was not inserted)
+  /// @throws IllegalStateException if the list, or the first bundle in it, is empty
+  private CompletableFuture<Map<AciServiceIdentifier, Map<Byte, InsertResult>>> insertChunk(
+      final Database database,
+      final List<Map.Entry<AciServiceIdentifier, Map<Byte, MessageProtos.Envelope>>> messagesByAccountIdentifier) {
+
+    final Map<AciServiceIdentifier, CompletableFuture<Map<Byte, Boolean>>> insertFuturesByAci = new HashMap<>();
+
+    // In a message bundle (single-recipient or MRM) the ephemerality should be the same for all envelopes, so just get the first.
+    final boolean ephemeral = messagesByAccountIdentifier.stream()
+        .findFirst()
+        .flatMap(entry -> entry.getValue().values().stream().findFirst())
+        .map(MessageProtos.Envelope::getEphemeral)
+        .orElseThrow(() -> new IllegalStateException("One or more bundles is empty"));
+
+    return database.runAsync(transaction -> {
+          messagesByAccountIdentifier.forEach(entry ->
+              insertFuturesByAci.put(entry.getKey(), insert(entry.getKey(), entry.getValue(), transaction)));
+
+          return CompletableFuture.allOf(insertFuturesByAci.values().toArray(CompletableFuture[]::new))
+              .thenApply(_ -> {
+                final boolean anyClientPresent = insertFuturesByAci.values()
+                    .stream()
+                    .map(CompletableFuture::join)
+                    .flatMap(presenceByDeviceId -> presenceByDeviceId.values().stream())
+                    .anyMatch(isPresent -> isPresent);
+                if (anyClientPresent || !ephemeral) {
+                  return transaction.getVersionstamp()
+                      .thenApply(versionstampBytes -> Optional.of(Versionstamp.complete(versionstampBytes)));
+                }
+                return CompletableFuture.completedFuture(Optional.<Versionstamp>empty());
+              });
+        })
+        .thenCompose(Function.identity())
+        .thenApply(maybeVersionstamp -> insertFuturesByAci.entrySet().stream()
+            .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
+              assert entry.getValue().isDone();
+              final Map<Byte, Boolean> presenceByDeviceId = entry.getValue().join();
+
+              return presenceByDeviceId.entrySet().stream()
+                  .collect(Collectors.toMap(Map.Entry::getKey, presenceEntry -> {
+                    final Optional<Versionstamp> insertResultVersionstamp;
+                    if (presenceEntry.getValue() || !ephemeral) {
+                      assert maybeVersionstamp.isPresent();
+                      insertResultVersionstamp = maybeVersionstamp;
+                    } else {
+                      insertResultVersionstamp = Optional.empty();
+                    }
+                    return new InsertResult(insertResultVersionstamp, presenceEntry.getValue());
+                  }));
+            })));
+  }
+
+  /// Insert a message bundle for a single recipient in an ongoing transaction.
+  ///
+  /// @implNote A message for a device is not inserted if it is offline and the message is ephemeral. Additionally, the
+  /// message watch key is updated iff at least one receiving device is present.
+  ///
+  /// @param aci accountId of the recipient
+  /// @param messagesByDeviceId map of destination deviceId => message envelopes
+  /// @param transaction the ongoing transaction
+  /// @return a future that yields the presence state of each destination device
+  private CompletableFuture<Map<Byte, Boolean>> insert(final AciServiceIdentifier aci,
+      final Map<Byte, MessageProtos.Envelope> messagesByDeviceId,
+      final Transaction transaction) {
+
+    final Map<Byte, CompletableFuture<Boolean>> messageInsertFuturesByDeviceId = messagesByDeviceId.entrySet()
+        .stream()
+        .collect(Collectors.toMap(Map.Entry::getKey, e -> {
+          final byte deviceId = e.getKey();
+          final MessageProtos.Envelope message = e.getValue();
+          final byte[] presenceKey = getPresenceKey(aci, deviceId);
+
+          return transaction.get(presenceKey)
+              .thenApply(this::isClientPresent)
+              .thenApply(isPresent -> {
+                if (isPresent || !message.getEphemeral()) {
+                  transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY,
+                      getDeviceQueueSubspace(aci, deviceId)
+                          .packWithVersionstamp(Tuple.from(Versionstamp.incomplete())), message.toByteArray());
+                }
+
+                return isPresent;
+              });
+        }));
+
+    return CompletableFuture.allOf(messageInsertFuturesByDeviceId.values().toArray(CompletableFuture[]::new))
+        .thenApplyAsync(_ -> {
+          final Map<Byte, Boolean> presenceByDeviceId = messageInsertFuturesByDeviceId.entrySet().stream()
+              .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
+                assert entry.getValue().isDone();
+                return entry.getValue().join();
+              }));
+
+          final boolean anyClientPresent = presenceByDeviceId.values().stream().anyMatch(present -> present);
+
+          if (anyClientPresent) {
+            transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getMessagesAvailableWatchKey(aci),
+                Tuple.from(Versionstamp.incomplete()).packWithVersionstamp());
+          }
+
+          return presenceByDeviceId;
+        }, executor);
+  }
+
+  @VisibleForTesting
+  Database getShardForAci(final AciServiceIdentifier aci) {
     return databases[hashAciToShardNumber(aci)];
   }
 
-
private int hashAciToShardNumber(final AciServiceIdentifier aci) { + @VisibleForTesting + int hashAciToShardNumber(final AciServiceIdentifier aci) { // We use a consistent hash here to reduce the number of key remappings if we increase the number of shards return Hashing.consistentHash(aci.uuid().getLeastSignificantBits(), databases.length); } @@ -139,7 +277,7 @@ public class FoundationDbMessageStore { @VisibleForTesting byte[] getPresenceKey(final AciServiceIdentifier aci, final byte deviceId) { - return getDeviceQueueSubspace(aci, deviceId).pack("p"); + return getDeviceSubspace(aci, deviceId).pack("p"); } @VisibleForTesting @@ -151,7 +289,6 @@ public class FoundationDbMessageStore { // The presence value is a long with the higher order 16 bits containing a server id, and the lower 48 bits // containing the timestamp (seconds since epoch) that the client updates periodically. final long lastSeenSecondsSinceEpoch = presenceValue & 0x0000ffffffffffffL; - return (clock.instant().getEpochSecond() - lastSeenSecondsSinceEpoch) <= MAX_SECONDS_SINCE_UPDATE_FOR_PRESENCE; + return (clock.instant().getEpochSecond() - lastSeenSecondsSinceEpoch) <= PRESENCE_STALE_THRESHOLD.toSeconds(); } - } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbClusterExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbClusterExtension.java new file mode 100644 index 000000000..32b5b7994 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbClusterExtension.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.FDB; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; + +import java.io.IOException; + +public class 
FoundationDbClusterExtension implements BeforeAllCallback, ExtensionContext.Store.CloseableResource {
+
+  private final FoundationDbDatabaseLifecycleManager[] databaseLifecycleManagers;
+  private final Database[] databases;
+
+  /// Creates an extension that manages `numInstances` independent FoundationDB databases.
+  ///
+  /// @param numInstances the number of databases to initialize; must be at least 1
+  /// @throws IllegalArgumentException if `numInstances` is less than 1
+  public FoundationDbClusterExtension(final int numInstances) {
+    if (numInstances < 1) {
+      throw new IllegalArgumentException("numInstances must be at least 1, but was " + numInstances);
+    }
+
+    this.databaseLifecycleManagers = new FoundationDbDatabaseLifecycleManager[numInstances];
+    this.databases = new Database[numInstances];
+  }
+
+  /// Initializes all databases on first invocation; later invocations are no-ops. If the
+  /// `foundationDb.serviceContainerNamePrefix` system property is set, instance `i` uses the service container named
+  /// `<prefix><i>`; otherwise a Testcontainers-backed lifecycle manager is used for each instance.
+  @Override
+  public void beforeAll(final ExtensionContext context) throws IOException {
+    if (databaseLifecycleManagers[0] == null) {
+      // JUnit only invokes CloseableResource#close for values held in a Store, so register this extension in the root
+      // store; without this, close() is never called and the databases are never shut down. Registering before
+      // initialization means a partially-initialized cluster still gets cleaned up. The instance itself is the key
+      // so that multiple extension instances don't replace one another.
+      context.getRoot().getStore(ExtensionContext.Namespace.GLOBAL).put(this, this);
+
+      final String serviceContainerNamePrefix = System.getProperty("foundationDb.serviceContainerNamePrefix");
+
+      for (int i = 0; i < databaseLifecycleManagers.length; i++) {
+        final FoundationDbDatabaseLifecycleManager databaseLifecycleManager = serviceContainerNamePrefix != null
+            ? new ServiceContainerFoundationDbDatabaseLifecycleManager(serviceContainerNamePrefix + i)
+            : new TestcontainersFoundationDbDatabaseLifecycleManager();
+        databaseLifecycleManager.initializeDatabase(FDB.selectAPIVersion(FoundationDbVersion.getFoundationDbApiVersion()));
+        databaseLifecycleManagers[i] = databaseLifecycleManager;
+        databases[i] = databaseLifecycleManager.getDatabase();
+      }
+
+    }
+  }
+
+  /// Returns the managed databases, indexed by instance number. Entries are `null` until [#beforeAll] has run.
+  public Database[] getDatabases() {
+    return databases;
+  }
+
+  /// Closes every database that was initialized.
+  @Override
+  public void close() throws Throwable {
+    for (final FoundationDbDatabaseLifecycleManager databaseLifecycleManager : databaseLifecycleManagers) {
+      // Slots are still null if initialization never ran or failed partway through
+      if (databaseLifecycleManager != null) {
+        databaseLifecycleManager.closeDatabase();
+      }
+    }
+  }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbExtension.java
deleted file mode 100644
index 5bab45d9a..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbExtension.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- *
SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import com.apple.foundationdb.Database; -import com.apple.foundationdb.FDB; -import java.io.IOException; -import org.junit.jupiter.api.extension.BeforeAllCallback; -import org.junit.jupiter.api.extension.ExtensionContext; - -public class FoundationDbExtension implements BeforeAllCallback, ExtensionContext.Store.CloseableResource { - - private static FoundationDbDatabaseLifecycleManager databaseLifecycleManager; - - @Override - public void beforeAll(final ExtensionContext context) throws IOException { - if (databaseLifecycleManager == null) { - final String serviceContainerName = System.getProperty("foundationDb.serviceContainerName"); - - databaseLifecycleManager = serviceContainerName != null - ? new ServiceContainerFoundationDbDatabaseLifecycleManager(serviceContainerName) - : new TestcontainersFoundationDbDatabaseLifecycleManager(); - - databaseLifecycleManager.initializeDatabase(FDB.selectAPIVersion(FoundationDbVersion.getFoundationDbApiVersion())); - - context.getRoot().getStore(ExtensionContext.Namespace.GLOBAL).put(getClass().getName(), this); - } - } - - public Database getDatabase() { - return databaseLifecycleManager.getDatabase(); - } - - @Override - public void close() throws Throwable { - if (databaseLifecycleManager != null) { - databaseLifecycleManager.closeDatabase(); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbTest.java deleted file mode 100644 index f19d509bb..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbTest.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import org.junit.jupiter.api.Test; -import 
org.junit.jupiter.api.extension.RegisterExtension; -import org.whispersystems.textsecuregcm.util.TestRandomUtil; -import java.nio.charset.StandardCharsets; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; - -public class FoundationDbTest { - - @RegisterExtension - static FoundationDbExtension FOUNDATION_DB_EXTENSION = new FoundationDbExtension(); - - @Test - void setGetValue() { - final byte[] key = "test".getBytes(StandardCharsets.UTF_8); - final byte[] value = TestRandomUtil.nextBytes(16); - - FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> { - transaction.set(key, value); - return null; - }); - - final byte[] retrievedValue = FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> transaction.get(key).join()); - - assertArrayEquals(value, retrievedValue); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index a40a7f660..33d9abcf6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -3,9 +3,11 @@ package org.whispersystems.textsecuregcm.storage.foundationdb; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import com.apple.foundationdb.Database; +import com.apple.foundationdb.KeyValue; +import com.apple.foundationdb.async.AsyncUtil; import com.apple.foundationdb.tuple.Tuple; import com.apple.foundationdb.tuple.Versionstamp; import com.google.protobuf.ByteString; @@ -14,16 +16,21 @@ import 
java.io.UncheckedIOException; import java.time.Clock; import java.time.Instant; import java.time.ZoneId; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.concurrent.Executors; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; +import io.dropwizard.util.DataSize; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; @@ -35,7 +42,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.FoundationDbExtension; +import org.whispersystems.textsecuregcm.storage.FoundationDbClusterExtension; import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.TestRandomUtil; @@ -43,7 +50,7 @@ import org.whispersystems.textsecuregcm.util.TestRandomUtil; class FoundationDbMessageStoreTest { @RegisterExtension - static FoundationDbExtension FOUNDATION_DB_EXTENSION = new FoundationDbExtension(); + static FoundationDbClusterExtension FOUNDATION_DB_EXTENSION = new FoundationDbClusterExtension(2); private FoundationDbMessageStore foundationDbMessageStore; @@ -52,7 +59,7 @@ class FoundationDbMessageStoreTest { @BeforeEach void setup() { foundationDbMessageStore = new FoundationDbMessageStore( - new Database[]{FOUNDATION_DB_EXTENSION.getDatabase()}, + FOUNDATION_DB_EXTENSION.getDatabases(), Executors.newVirtualThreadPerTaskExecutor(), CLOCK); } @@ -60,7 +67,7 @@ class FoundationDbMessageStoreTest { @ParameterizedTest @MethodSource void insert(final long 
presenceUpdatedBeforeSeconds, final boolean ephemeral, final boolean expectMessagesInserted, - final boolean expectVersionstampUpdated) { + final boolean expectVersionstampUpdated, final boolean expectPresenceState) { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); final List deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6) .mapToObj(i -> (byte) i) @@ -68,15 +75,19 @@ class FoundationDbMessageStoreTest { deviceIds.forEach(deviceId -> writePresenceKey(aci, deviceId, 1, presenceUpdatedBeforeSeconds)); final Map messagesByDeviceId = deviceIds.stream() .collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral))); - final Optional versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); - assertNotNull(versionstamp); + final Map result = foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); + assertNotNull(result); + final Optional returnedVersionstamp = result.values().stream().findFirst() + .flatMap(FoundationDbMessageStore.InsertResult::versionstamp); if (expectMessagesInserted) { - assertTrue(versionstamp.isPresent()); + assertTrue(returnedVersionstamp.isPresent()); + assertTrue(result.values().stream().allMatch(insertResult -> returnedVersionstamp.equals(insertResult.versionstamp()))); final Map storedMessagesByDeviceId = deviceIds.stream() .collect(Collectors.toMap(Function.identity(), deviceId -> { try { - return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp.get())); + return MessageProtos.Envelope.parseFrom( + getMessageByVersionstamp(aci, deviceId, returnedVersionstamp.get())); } catch (final InvalidProtocolBufferException e) { throw new UncheckedIOException(e); } @@ -84,28 +95,32 @@ class FoundationDbMessageStoreTest { assertEquals(messagesByDeviceId, storedMessagesByDeviceId); } else { - assertTrue(versionstamp.isEmpty()); + assertTrue(result.values().stream().allMatch(insertResult -> 
insertResult.versionstamp().isEmpty())); } if (expectVersionstampUpdated) { - assertEquals(versionstamp, getMessagesAvailableWatch(aci), + final Optional messagesAvailableWatchVersionstamp = getMessagesAvailableWatch(aci); + assertTrue(messagesAvailableWatchVersionstamp.isPresent()); + assertEquals(returnedVersionstamp, messagesAvailableWatchVersionstamp, "messages available versionstamp should be the versionstamp of the last insert transaction"); } else { assertTrue(getMessagesAvailableWatch(aci).isEmpty()); } + + assertTrue(result.values().stream().allMatch(insertResult -> insertResult.present() == expectPresenceState)); } private static Stream insert() { return Stream.of( Arguments.argumentSet("Non-ephemeral messages with all devices online", - 10L, false, true, true), + 10L, false, true, true, true), Arguments.argumentSet( "Ephemeral messages with presence updated exactly at the second before which the device would be considered offline", - 300L, true, true, true), + 300L, true, true, true, true), Arguments.argumentSet("Non-ephemeral messages for with all devices offline", - 310L, false, true, false), + 310L, false, true, false, false), Arguments.argumentSet("Ephemeral messages with all devices offline", - 310L, true, false, false) + 310L, true, false, false, false) ); } @@ -113,10 +128,15 @@ class FoundationDbMessageStoreTest { void versionstampCorrectlyUpdatedOnMultipleInserts() { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L); - foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join(); - final Optional secondMessageVersionstamp = foundationDbMessageStore.insert(aci, + foundationDbMessageStore.insert(Map.of(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false)))).join(); + final Map secondMessageInsertResult = foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join(); - 
assertEquals(secondMessageVersionstamp, getMessagesAvailableWatch(aci));
+
+    final Optional messagesAvailableWatchVersionstamp = getMessagesAvailableWatch(aci);
+    assertTrue(messagesAvailableWatchVersionstamp.isPresent());
+    assertEquals(
+        secondMessageInsertResult.get(Device.PRIMARY_ID).versionstamp(),
+        messagesAvailableWatchVersionstamp);
   }
 
   @ParameterizedTest
@@ -130,24 +150,29 @@ class FoundationDbMessageStoreTest {
     writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L);
     final Map messagesByDeviceId = deviceIds.stream()
         .collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral)));
-    final Optional versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
-    assertNotNull(versionstamp);
-    assertTrue(versionstamp.isPresent(),
-        "versionstamp should be present since at least one message should be inserted");
+    final Map result = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
+    assertNotNull(result);
+    final Optional returnedVersionstamp = result.get(Device.PRIMARY_ID).versionstamp();
+    assertTrue(returnedVersionstamp.isPresent(),
+        "versionstamp should be present for online device");
 
     assertArrayEquals(
         messagesByDeviceId.get(Device.PRIMARY_ID).toByteArray(),
-        getMessageByVersionstamp(aci, Device.PRIMARY_ID, versionstamp.get()),
+        getMessageByVersionstamp(aci, Device.PRIMARY_ID, returnedVersionstamp.get()),
         "Message for primary should always be stored since it has a recently updated presence");
 
     if (ephemeral) {
       assertTrue(IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID + 6)
-          .mapToObj(deviceId -> getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get()))
+          .mapToObj(deviceId -> getMessageByVersionstamp(aci, (byte) deviceId, returnedVersionstamp.get()))
           .allMatch(Objects::isNull), "Ephemeral messages for non-present devices must not be stored");
+      assertTrue(IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID + 6)
+          .mapToObj(deviceId -> result.get((byte) deviceId).versionstamp())
+          .allMatch(Optional::isEmpty),
+          "Unexpected versionstamp found for one or more devices that didn't have any messages inserted");
     } else {
       IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID)
           .forEach(deviceId -> {
-            final byte[] messageBytes = getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get());
+            final byte[] messageBytes = getMessageByVersionstamp(aci, (byte) deviceId, returnedVersionstamp.get());
             assertEquals(messagesByDeviceId.get((byte) deviceId).toByteArray(), messageBytes,
                 "Non-ephemeral messages must always be stored");
           });
@@ -169,23 +194,188 @@ class FoundationDbMessageStoreTest {
             Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(5))), true),
         Arguments.argumentSet("Presence updated same second as current time",
             Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(0))), true),
-        Arguments.argumentSet("Presence updated exactly at the second before which it would have been considered offline",
+        Arguments.argumentSet(
+            "Presence updated exactly at the second before which it would have been considered offline",
             Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(300))), true),
         Arguments.argumentSet("Presence expired",
             Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(400))), false)
     );
   }
 
+  /// Represents a cohort of recipients with the same config
+  record MultiRecipientTestConfig(int shardNum, int numRecipients, boolean devicePresent,
+      boolean generateEphemeralMessages, boolean expectMessagesInserted) {}
+
+  /// Inserts one consolidated multi-recipient bundle built from the given cohorts, then checks the number of distinct
+  /// versionstamps returned per shard, each recipient's primary device queue, and the reported presence states.
+  @ParameterizedTest
+  @MethodSource
+  void insertMultiRecipient(final List testConfigs, final DataSize contentSize,
+      final int[] expectedNumTransactionsByShard) {
+    // Generate a list of ACIs for each test config
+    final List> acisByConfig = testConfigs.stream()
+        .map(testConfig -> IntStream.range(0, testConfig.numRecipients())
+            .mapToObj(_ -> generateRandomAciForShard(testConfig.shardNum()))
+            .toList())
+        .toList();
+
+    // Generate MRM bundles for each ACI, for each test config. Later, we'll assert if the stored messages (if expected)
+    // are the same as those we generated.
+    final List>> mrmByConfig = IntStream.range(0,
+            testConfigs.size())
+        .mapToObj(i -> {
+          final List acis = acisByConfig.get(i);
+          final MultiRecipientTestConfig testConfig = testConfigs.get(i);
+          return acis.stream()
+              .collect(Collectors.toMap(
+                  Function.identity(),
+                  _ -> Map.of(Device.PRIMARY_ID,
+                      generateRandomMessage(testConfig.generateEphemeralMessages(), (int) contentSize.toBytes()))));
+
+        })
+        .toList();
+
+    // Create the consolidated MRM bundle by ACI.
+    final Map> mrmBundle = new HashMap<>();
+    mrmByConfig.forEach(mrmBundle::putAll);
+
+    // Write a presence key for the cohort of recipients if the config indicates that the device must be present.
+    for (int i = 0; i < testConfigs.size(); i++) {
+      final List acis = acisByConfig.get(i);
+      final MultiRecipientTestConfig testConfig = testConfigs.get(i);
+      if (testConfig.devicePresent()) {
+        acis.forEach(aci -> writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L));
+      }
+    }
+
+    final Map> result = foundationDbMessageStore.insert(mrmBundle).join();
+    assertNotNull(result);
+
+    // Compute the set of versionstamps by shard number from the individual device insert results, so that we can
+    // assert that each shard has the expected number of committed transactions.
+    final Map> returnedVersionstampsByShard = new HashMap<>();
+    result.forEach((aci, deviceResults) -> {
+      final int shardNum = foundationDbMessageStore.hashAciToShardNumber(aci);
+      final Set versionstampSet = returnedVersionstampsByShard.computeIfAbsent(shardNum, _ -> new HashSet<>());
+      deviceResults.forEach((_, deviceResult) -> deviceResult.versionstamp().ifPresent(versionstampSet::add));
+    });
+
+    final int[] returnedNumVersionstampsByShard = new int[FOUNDATION_DB_EXTENSION.getDatabases().length];
+    for (int i = 0; i < returnedNumVersionstampsByShard.length; i++) {
+      returnedNumVersionstampsByShard[i] = returnedVersionstampsByShard.getOrDefault(i, Collections.emptySet()).size();
+    }
+
+    assertArrayEquals(expectedNumTransactionsByShard, returnedNumVersionstampsByShard);
+
+    // For each cohort of recipients, check whether the stored messages (if expected) are the same as those we inserted
+    // and whether the returned device presence states are the same as the configured states.
+    IntStream.range(0, testConfigs.size()).forEach(i -> {
+      final List acis = acisByConfig.get(i);
+      final MultiRecipientTestConfig testConfig = testConfigs.get(i);
+      if (testConfig.expectMessagesInserted()) {
+        final Map> storedMrmBundle = acis.stream()
+            .collect(Collectors.toMap(Function.identity(), aci -> {
+              final List items = getItemsInDeviceQueue(aci, Device.PRIMARY_ID);
+              assertEquals(1, items.size());
+              try {
+                final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(items.getFirst().getValue());
+                return Map.of(Device.PRIMARY_ID, envelope);
+              } catch (final InvalidProtocolBufferException e) {
+                throw new UncheckedIOException(e);
+              }
+            }));
+        assertEquals(mrmByConfig.get(i), storedMrmBundle,
+            "Stored message bundle does not match inserted message bundle");
+      } else {
+        assertEquals(0, acis
+            .stream()
+            .mapToInt(aci -> getItemsInDeviceQueue(aci, Device.PRIMARY_ID).size())
+            .sum(), "Unexpected messages found in device queue");
+      }
+
+      assertTrue(acis
+              .stream()
+              .allMatch(
+                  aci -> result.get(aci).get(Device.PRIMARY_ID).present() == testConfig.devicePresent()),
+          "Device presence state from insert result does not match expected state");
+    });
+  }
+
+  /// Supplies cohort configs, message content size, and expected per-shard transaction counts for the test above.
+  static Stream insertMultiRecipient() {
+    return Stream.of(
+        Arguments.argumentSet("Multiple recipients on a single shard should result in a single transaction",
+            List.of(
+                new MultiRecipientTestConfig(0, 5, true, false, true)),
+            DataSize.bytes(128), new int[] {1, 0}),
+        Arguments.argumentSet(
+            "Multiple recipients on a single shard exceeding the transaction limit should be broken up into multiple transactions",
+            List.of(
+                new MultiRecipientTestConfig(0, 15, true, false, true)),
+            DataSize.kilobytes(90), new int[] {2, 0}),
+        Arguments.argumentSet("Multiple recipients on different shards should result in multiple transactions",
+            List.of(
+                new MultiRecipientTestConfig(0, 5, true, false, true),
+                new MultiRecipientTestConfig(1, 5, true, false, true)),
+            DataSize.bytes(128), new int[] {1, 1}),
+        Arguments.argumentSet(
+            "Multiple recipients on different shards each exceeding the transaction limit should be broken up into multiple transactions on each shard",
+            List.of(
+                new MultiRecipientTestConfig(0, 15, true, false, true),
+                new MultiRecipientTestConfig(1, 15, true, false, true)),
+            DataSize.kilobytes(90), new int[] {2, 2}),
+        Arguments.argumentSet(
+            "Multiple recipients on a single shard with ephemeral messages and no devices present should result in no transactions committed",
+            List.of(
+                new MultiRecipientTestConfig(0, 5, false, true, false)),
+            DataSize.bytes(128), new int[] {0, 0}),
+        Arguments.argumentSet(
+            "Multiple recipients on different shards with ephemeral messages and no devices present should result in no transactions committed",
+            List.of(
+                new MultiRecipientTestConfig(0, 5, false, true, false),
+                new MultiRecipientTestConfig(1, 5, false, true, false)),
+            DataSize.bytes(128), new int[] {0, 0}),
+        Arguments.argumentSet(
+            "Multiple recipients on two shards with one shard having no devices present should result in only one transaction",
+            List.of(
+                new MultiRecipientTestConfig(0, 5, false, true, false),
+                new MultiRecipientTestConfig(1, 5, true, true, true)),
+            DataSize.bytes(128), new int[] {0, 1}),
+        Arguments.argumentSet(
+            "Multiple recipients on a single shard with some recipients having no devices present should result in only one transaction",
+            List.of(
+                new MultiRecipientTestConfig(0, 3, false, true, false),
+                new MultiRecipientTestConfig(0, 3, true, true, true)),
+            DataSize.bytes(128), new int[] {1, 0}),
+        Arguments.argumentSet(
+            "Multiple recipients on a single shard with total size just exceeding 2 chunks should result in 3 transactions",
+            List.of(
+                new MultiRecipientTestConfig(0, 23, true, false, true)),
+            DataSize.kilobytes(90), new int[] {3, 0})
+    );
+  }
+
+  /// A recipient whose device-to-message map is empty is expected to be rejected with an `IllegalArgumentException`.
+  @Test
+  void insertEmptyBundle() {
+    assertThrows(IllegalArgumentException.class, () -> foundationDbMessageStore.insert(
+        Map.of(generateRandomAciForShard(0), Collections.emptyMap())));
+  }
+
   private static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral) {
+    return generateRandomMessage(ephemeral, 16);
+  }
+
+  private static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral, final int contentSize) {
     return MessageProtos.Envelope.newBuilder()
-        .setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(16)))
+        .setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(contentSize)))
         .setEphemeral(ephemeral)
         .build();
   }
 
   private byte[] getMessageByVersionstamp(final AciServiceIdentifier aci, final byte deviceId,
       final Versionstamp versionstamp) {
-    return FOUNDATION_DB_EXTENSION.getDatabase().read(transaction -> {
+    return foundationDbMessageStore.getShardForAci(aci).read(transaction -> {
       final byte[] key = foundationDbMessageStore.getDeviceQueueSubspace(aci, deviceId)
           .pack(Tuple.from(versionstamp));
       return transaction.get(key);
@@ -193,7 +383,7 @@ class FoundationDbMessageStoreTest {
   }
 
   private Optional getMessagesAvailableWatch(final AciServiceIdentifier aci) {
-    return FOUNDATION_DB_EXTENSION.getDatabase()
+    return foundationDbMessageStore.getShardForAci(aci)
         .read(transaction -> transaction.get(foundationDbMessageStore.getMessagesAvailableWatchKey(aci))
             .thenApply(value -> value == null ? null : Tuple.fromBytes(value).getVersionstamp(0))
             .thenApply(Optional::ofNullable))
@@ -202,7 +392,7 @@ class FoundationDbMessageStoreTest {
 
   private void writePresenceKey(final AciServiceIdentifier aci, final byte deviceId, final int serverId,
       final long secondsBeforeCurrentTime) {
-    FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> {
+    foundationDbMessageStore.getShardForAci(aci).run(transaction -> {
       final byte[] presenceKey = foundationDbMessageStore.getPresenceKey(aci, deviceId);
       final long presenceUpdateEpochSeconds = getEpochSecondsBeforeClock(secondsBeforeCurrentTime);
       final long presenceValue = constructPresenceValue(serverId, presenceUpdateEpochSeconds);
@@ -219,4 +409,29 @@ class FoundationDbMessageStoreTest {
     return (long) (serverId & 0x0ffff) << 48 | (presenceUpdateEpochSeconds & 0x0000ffffffffffffL);
   }
 
+  /// Generates random ACIs until one hashes to the given shard number.
+  ///
+  /// @throws IllegalArgumentException if `shardNumber` is not a valid index into the extension's databases
+  private AciServiceIdentifier generateRandomAciForShard(final int shardNumber) {
+    final int numShards = FOUNDATION_DB_EXTENSION.getDatabases().length;
+    if (shardNumber < 0 || shardNumber >= numShards) {
+      // Fail fast (even when JVM assertions are disabled): an out-of-range shard would presumably never be matched
+      throw new IllegalArgumentException(
+          "Shard number " + shardNumber + " is out of range; expected 0 <= shardNumber < " + numShards);
+    }
+
+    while (true) {
+      final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
+      if (foundationDbMessageStore.hashAciToShardNumber(aci) == shardNumber) {
+        return aci;
+      }
+    }
+  }
+
+  /// Collects all entries in the given device's queue subspace, read from the shard for the given ACI.
+  private List getItemsInDeviceQueue(final AciServiceIdentifier aci, final byte deviceId) {
+    return foundationDbMessageStore.getShardForAci(aci).readAsync(transaction -> AsyncUtil.collect(transaction.getRange(
+        foundationDbMessageStore.getDeviceQueueSubspace(aci, deviceId).range()))).join();
+  }
+
 }