diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java new file mode 100644 index 000000000..562c7ecdb --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -0,0 +1,97 @@ +package org.whispersystems.textsecuregcm.storage.foundationdb; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.MutationType; +import com.apple.foundationdb.Transaction; +import com.apple.foundationdb.subspace.Subspace; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.foundationdb.tuple.Versionstamp; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.hash.Hashing; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.function.Function; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; + +/// An implementation of a message store backed by FoundationDB. +/// +/// @implNote The layout of elements in FoundationDB is as follows: +/// * messages +/// * {aci} +/// * last => versionstamp +/// * {deviceId} +/// * queue +/// * {versionstamp_1} => envelope_1 +/// * {versionstamp_2} => envelope_2 +public class FoundationDbMessageStore { + + private final Database[] databases; + private static final Subspace MESSAGES_SUBSPACE = new Subspace(Tuple.from("M")); + private final Executor executor; + + public FoundationDbMessageStore(final Database[] databases, final Executor executor) { + this.databases = databases; + this.executor = executor; + } + + /** + * Insert a message bundle for a set of devices belonging to a single recipient + * + * @param aci destination account identifier + * @param messagesByDeviceId a map of deviceId => message envelope + * @return a future that completes with a {@link Versionstamp} of the committed transaction + */ + public CompletableFuture insert(final AciServiceIdentifier aci, + final Map messagesByDeviceId) { + // We use Database#runAsync and not Database#run here because the latter would commit the transaction synchronously + // and we would like to avoid any potential blocking in native code that could unexpectedly pin virtual threads. See https://forums.foundationdb.org/t/fdbdatabase-usage-from-java-api/593/2 + // for details. + return getShardForAci(aci).runAsync(transaction -> { + insert(aci, messagesByDeviceId, transaction); + return CompletableFuture.completedFuture(transaction.getVersionstamp()); + }) + .thenComposeAsync(Function.identity(), executor) + .thenApply(Versionstamp::complete); + } + + private void insert(final AciServiceIdentifier aci, final Map messagesByDeviceId, + final Transaction transaction) { + messagesByDeviceId.forEach((deviceId, message) -> { + final Subspace deviceQueueSubspace = getDeviceQueueSubspace(aci, deviceId); + transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY, deviceQueueSubspace.packWithVersionstamp(Tuple.from( + Versionstamp.incomplete())), message.toByteArray()); + }); + transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getLastMessageKey(aci), + Tuple.from(Versionstamp.incomplete()).packWithVersionstamp()); + } + + private Database getShardForAci(final AciServiceIdentifier aci) { + return databases[hashAciToShardNumber(aci)]; + } + + private int hashAciToShardNumber(final AciServiceIdentifier aci) { + // We use a consistent hash here to reduce the number of key remappings if we increase the number of shards + return Hashing.consistentHash(aci.uuid().getLeastSignificantBits(), databases.length); + } + + @VisibleForTesting + Subspace getDeviceQueueSubspace(final AciServiceIdentifier aci, final byte deviceId) { + return getDeviceSubspace(aci, deviceId).get("Q"); + } + + private Subspace getDeviceSubspace(final AciServiceIdentifier aci, final byte deviceId) { + return getAccountSubspace(aci).get(deviceId); + } + + private Subspace getAccountSubspace(final AciServiceIdentifier aci) { + return MESSAGES_SUBSPACE.get(aci.uuid()); + } + + @VisibleForTesting + byte[] getLastMessageKey(final AciServiceIdentifier aci) { + return getAccountSubspace(aci).pack("l"); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbExtension.java index 4cc352af6..5bab45d9a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbExtension.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/FoundationDbExtension.java @@ -11,7 +11,7 @@ import java.io.IOException; import org.junit.jupiter.api.extension.BeforeAllCallback; import org.junit.jupiter.api.extension.ExtensionContext; -class FoundationDbExtension implements BeforeAllCallback, ExtensionContext.Store.CloseableResource { +public class FoundationDbExtension implements BeforeAllCallback, ExtensionContext.Store.CloseableResource { private static FoundationDbDatabaseLifecycleManager databaseLifecycleManager; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java new file mode 100644 index 000000000..091820203 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -0,0 +1,101 @@ +package org.whispersystems.textsecuregcm.storage.foundationdb; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.apple.foundationdb.Database; +import com.apple.foundationdb.tuple.Tuple; +import com.apple.foundationdb.tuple.Versionstamp; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.Executors; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.FoundationDbExtension; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; + +@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +class FoundationDbMessageStoreTest { + + @RegisterExtension + static FoundationDbExtension FOUNDATION_DB_EXTENSION = new FoundationDbExtension(); + + private FoundationDbMessageStore foundationDbMessageStore; + + @BeforeEach + void setup() { + foundationDbMessageStore = new FoundationDbMessageStore( + new Database[]{FOUNDATION_DB_EXTENSION.getDatabase()}, + Executors.newVirtualThreadPerTaskExecutor()); + } + + @Test + void insert() { + final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); + final List deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6) + .mapToObj(i -> (byte) i) + .toList(); + final Map messagesByDeviceId = deviceIds.stream() + .collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage())); + final Versionstamp versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); + assertNotNull(versionstamp); + + final Map storedMessagesByDeviceId = deviceIds.stream() + .collect(Collectors.toMap(Function.identity(), deviceId -> { + try { + return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp)); + } catch (final InvalidProtocolBufferException e) { + throw new UncheckedIOException(e); + } + })); + + assertEquals(messagesByDeviceId, storedMessagesByDeviceId); + assertEquals(versionstamp, getLastMessageVersionstamp(aci), + "last message versionstamp should be the versionstamp of the last insert transaction"); + } + + @Test + void versionstampCorrectlyUpdatedOnMultipleInserts() { + final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); + foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage())).join(); + final Versionstamp secondMessageVersionstamp = foundationDbMessageStore.insert(aci, + Map.of(Device.PRIMARY_ID, generateRandomMessage())).join(); + assertEquals(secondMessageVersionstamp, getLastMessageVersionstamp(aci)); + } + + private static MessageProtos.Envelope generateRandomMessage() { + return MessageProtos.Envelope.newBuilder() + .setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(16))) + .build(); + } + + private byte[] getMessageByVersionstamp(final AciServiceIdentifier aci, final byte deviceId, + final Versionstamp versionstamp) { + return FOUNDATION_DB_EXTENSION.getDatabase().read(transaction -> { + final byte[] key = foundationDbMessageStore.getDeviceQueueSubspace(aci, deviceId) + .pack(Tuple.from(versionstamp)); + return transaction.get(key); + }).join(); + } + + private Versionstamp getLastMessageVersionstamp(final AciServiceIdentifier aci) { + return FOUNDATION_DB_EXTENSION.getDatabase() + .read(transaction -> transaction.get(foundationDbMessageStore.getLastMessageKey(aci)) + .thenApply(Tuple::fromBytes) + .thenApply(t -> t.getVersionstamp(0))) + .join(); + } + +}