diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index 562c7ecdb..1d20e104c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -8,63 +8,106 @@ import com.apple.foundationdb.tuple.Tuple; import com.apple.foundationdb.tuple.Versionstamp; import com.google.common.annotations.VisibleForTesting; import com.google.common.hash.Hashing; +import java.time.Clock; +import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.function.Function; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.util.Conversions; +import org.whispersystems.textsecuregcm.util.Pair; /// An implementation of a message store backed by FoundationDB. /// /// @implNote The layout of elements in FoundationDB is as follows: /// * messages /// * {aci} -/// * last => versionstamp +/// * messageAvailableWatch => versionstamp /// * {deviceId} +/// * presence => server_id | last_seen_seconds_since_epoch /// * queue /// * {versionstamp_1} => envelope_1 /// * {versionstamp_2} => envelope_2 public class FoundationDbMessageStore { private final Database[] databases; - private static final Subspace MESSAGES_SUBSPACE = new Subspace(Tuple.from("M")); private final Executor executor; + private final Clock clock; - public FoundationDbMessageStore(final Database[] databases, final Executor executor) { + private static final Subspace MESSAGES_SUBSPACE = new Subspace(Tuple.from("M")); + private static final int MAX_SECONDS_SINCE_UPDATE_FOR_PRESENCE = 300; + + public FoundationDbMessageStore(final Database[] databases, final Executor executor, final Clock clock) { this.databases = databases; this.executor = executor; + this.clock = clock; } - /** - * Insert a message bundle for a set of devices belonging to a single recipient - * - * @param aci destination account identifier - * @param messagesByDeviceId a map of deviceId => message envelope - * @return a future that completes with a {@link Versionstamp} of the committed transaction - */ - public CompletableFuture insert(final AciServiceIdentifier aci, + /// Insert a message bundle for a set of devices belonging to a single recipient. A message may not be inserted if the + /// device is not present (as determined from its presence key) and the message is ephemeral. If all messages in the + /// bundle don't end up being inserted, we won't return a versionstamp since the transaction was read-only. + /// + /// @param aci destination account identifier + /// @param messagesByDeviceId a map of deviceId => message envelope + /// @return a future that completes with a [Versionstamp] of the committed transaction if at least one message was + /// inserted + public CompletableFuture> insert(final AciServiceIdentifier aci, final Map messagesByDeviceId) { // We use Database#runAsync and not Database#run here because the latter would commit the transaction synchronously // and we would like to avoid any potential blocking in native code that could unexpectedly pin virtual threads. See https://forums.foundationdb.org/t/fdbdatabase-usage-from-java-api/593/2 // for details. - return getShardForAci(aci).runAsync(transaction -> { - insert(aci, messagesByDeviceId, transaction); - return CompletableFuture.completedFuture(transaction.getVersionstamp()); - }) + return getShardForAci(aci).runAsync(transaction -> insert(aci, messagesByDeviceId, transaction) + .thenApply(hasMutations -> { + if (hasMutations) { + return transaction.getVersionstamp(); + } + return CompletableFuture.completedFuture((byte[]) null); + })) .thenComposeAsync(Function.identity(), executor) - .thenApply(Versionstamp::complete); + .thenApply(versionstampBytes -> Optional.ofNullable(versionstampBytes).map(Versionstamp::complete)); } - private void insert(final AciServiceIdentifier aci, final Map messagesByDeviceId, + private CompletableFuture insert(final AciServiceIdentifier aci, + final Map messagesByDeviceId, final Transaction transaction) { - messagesByDeviceId.forEach((deviceId, message) -> { - final Subspace deviceQueueSubspace = getDeviceQueueSubspace(aci, deviceId); - transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY, deviceQueueSubspace.packWithVersionstamp(Tuple.from( - Versionstamp.incomplete())), message.toByteArray()); - }); - transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getLastMessageKey(aci), - Tuple.from(Versionstamp.incomplete()).packWithVersionstamp()); + final List>> messageInsertFutures = messagesByDeviceId.entrySet() + .stream() + .map(e -> { + final byte deviceId = e.getKey(); + final MessageProtos.Envelope message = e.getValue(); + final byte[] presenceKey = getPresenceKey(aci, deviceId); + return transaction.get(presenceKey) + .thenApply(this::isClientPresent) + .thenApply(isPresent -> { + boolean hasMutations = false; + if (isPresent || !message.getEphemeral()) { + final Subspace deviceQueueSubspace = getDeviceQueueSubspace(aci, deviceId); + transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY, + deviceQueueSubspace.packWithVersionstamp(Tuple.from( + Versionstamp.incomplete())), message.toByteArray()); + hasMutations = true; + } + return new Pair<>(isPresent, hasMutations); + }); + }) + .toList(); + return CompletableFuture.allOf(messageInsertFutures.toArray(CompletableFuture[]::new)) + .thenApply(_ -> { + final boolean anyClientPresent = messageInsertFutures + .stream() + .anyMatch(future -> future.join().first()); + final boolean hasMutations = messageInsertFutures + .stream() + .anyMatch(future -> future.join().second()); + if (anyClientPresent && hasMutations) { + transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getMessagesAvailableWatchKey(aci), + Tuple.from(Versionstamp.incomplete()).packWithVersionstamp()); + } + return hasMutations; + }); } private Database getShardForAci(final AciServiceIdentifier aci) { @@ -90,8 +133,25 @@ public class FoundationDbMessageStore { } @VisibleForTesting - byte[] getLastMessageKey(final AciServiceIdentifier aci) { + byte[] getMessagesAvailableWatchKey(final AciServiceIdentifier aci) { return getAccountSubspace(aci).pack("l"); } + @VisibleForTesting + byte[] getPresenceKey(final AciServiceIdentifier aci, final byte deviceId) { + return getDeviceQueueSubspace(aci, deviceId).pack("p"); + } + + @VisibleForTesting + boolean isClientPresent(final byte[] presenceValueBytes) { + if (presenceValueBytes == null) { + return false; + } + final long presenceValue = Conversions.byteArrayToLong(presenceValueBytes); + // The presence value is a long with the higher order 16 bits containing a server id, and the lower 48 bits + // containing the timestamp (seconds since epoch) that the client updates periodically. + final long lastSeenSecondsSinceEpoch = presenceValue & 0x0000ffffffffffffL; + return (clock.instant().getEpochSecond() - lastSeenSecondsSinceEpoch) <= MAX_SECONDS_SINCE_UPDATE_FOR_PRESENCE; + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index 091820203..a40a7f660 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -1,7 +1,9 @@ package org.whispersystems.textsecuregcm.storage.foundationdb; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import com.apple.foundationdb.Database; import com.apple.foundationdb.tuple.Tuple; @@ -9,21 +11,32 @@ import com.apple.foundationdb.tuple.Versionstamp; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import java.io.UncheckedIOException; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.Executors; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.FoundationDbExtension; +import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.TestRandomUtil; @Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) @@ -34,50 +47,139 @@ class FoundationDbMessageStoreTest { private FoundationDbMessageStore foundationDbMessageStore; + private static final Clock CLOCK = Clock.fixed(Instant.ofEpochSecond(500), ZoneId.of("UTC")); + @BeforeEach void setup() { foundationDbMessageStore = new FoundationDbMessageStore( new Database[]{FOUNDATION_DB_EXTENSION.getDatabase()}, - Executors.newVirtualThreadPerTaskExecutor()); + Executors.newVirtualThreadPerTaskExecutor(), + CLOCK); } - @Test - void insert() { + @ParameterizedTest + @MethodSource + void insert(final long presenceUpdatedBeforeSeconds, final boolean ephemeral, final boolean expectMessagesInserted, + final boolean expectVersionstampUpdated) { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); final List deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6) .mapToObj(i -> (byte) i) .toList(); + deviceIds.forEach(deviceId -> writePresenceKey(aci, deviceId, 1, presenceUpdatedBeforeSeconds)); final Map messagesByDeviceId = deviceIds.stream() - .collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage())); - final Versionstamp versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); + .collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral))); + final Optional versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); assertNotNull(versionstamp); - final Map storedMessagesByDeviceId = deviceIds.stream() - .collect(Collectors.toMap(Function.identity(), deviceId -> { - try { - return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp)); - } catch (final InvalidProtocolBufferException e) { - throw new UncheckedIOException(e); - } - })); + if (expectMessagesInserted) { + assertTrue(versionstamp.isPresent()); + final Map storedMessagesByDeviceId = deviceIds.stream() + .collect(Collectors.toMap(Function.identity(), deviceId -> { + try { + return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp.get())); + } catch (final InvalidProtocolBufferException e) { + throw new UncheckedIOException(e); + } + })); - assertEquals(messagesByDeviceId, storedMessagesByDeviceId); - assertEquals(versionstamp, getLastMessageVersionstamp(aci), - "last message versionstamp should be the versionstamp of the last insert transaction"); + assertEquals(messagesByDeviceId, storedMessagesByDeviceId); + } else { + assertTrue(versionstamp.isEmpty()); + } + + if (expectVersionstampUpdated) { + assertEquals(versionstamp, getMessagesAvailableWatch(aci), + "messages available versionstamp should be the versionstamp of the last insert transaction"); + } else { + assertTrue(getMessagesAvailableWatch(aci).isEmpty()); + } + } + + private static Stream insert() { + return Stream.of( + Arguments.argumentSet("Non-ephemeral messages with all devices online", + 10L, false, true, true), + Arguments.argumentSet( + "Ephemeral messages with presence updated exactly at the second before which the device would be considered offline", + 300L, true, true, true), + Arguments.argumentSet("Non-ephemeral messages for with all devices offline", + 310L, false, true, false), + Arguments.argumentSet("Ephemeral messages with all devices offline", + 310L, true, false, false) + ); } @Test void versionstampCorrectlyUpdatedOnMultipleInserts() { final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); - foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage())).join(); - final Versionstamp secondMessageVersionstamp = foundationDbMessageStore.insert(aci, - Map.of(Device.PRIMARY_ID, generateRandomMessage())).join(); - assertEquals(secondMessageVersionstamp, getLastMessageVersionstamp(aci)); + writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L); + foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join(); + final Optional secondMessageVersionstamp = foundationDbMessageStore.insert(aci, + Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join(); + assertEquals(secondMessageVersionstamp, getMessagesAvailableWatch(aci)); } - private static MessageProtos.Envelope generateRandomMessage() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void insertOnlyOneDevicePresent(final boolean ephemeral) { + final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); + final List deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6) + .mapToObj(i -> (byte) i) + .toList(); + // Only 1 device has a recent presence, the others do not have presence keys present. + writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L); + final Map messagesByDeviceId = deviceIds.stream() + .collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral))); + final Optional versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); + assertNotNull(versionstamp); + assertTrue(versionstamp.isPresent(), + "versionstamp should be present since at least one message should be inserted"); + + assertArrayEquals( + messagesByDeviceId.get(Device.PRIMARY_ID).toByteArray(), + getMessageByVersionstamp(aci, Device.PRIMARY_ID, versionstamp.get()), + "Message for primary should always be stored since it has a recently updated presence"); + + if (ephemeral) { + assertTrue(IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID + 6) + .mapToObj(deviceId -> getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get())) + .allMatch(Objects::isNull), "Ephemeral messages for non-present devices must not be stored"); + } else { + IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID) + .forEach(deviceId -> { + final byte[] messageBytes = getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get()); + assertEquals(messagesByDeviceId.get((byte) deviceId).toByteArray(), messageBytes, + "Non-ephemeral messages must always be stored"); + }); + } + + } + + @ParameterizedTest + @MethodSource + void isClientPresent(final byte[] presenceValueBytes, final boolean expectPresent) { + assertEquals(expectPresent, foundationDbMessageStore.isClientPresent(presenceValueBytes)); + } + + static Stream isClientPresent() { + return Stream.of( + Arguments.argumentSet("Presence value doesn't exist", + null, false), + Arguments.argumentSet("Presence updated recently", + Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(5))), true), + Arguments.argumentSet("Presence updated same second as current time", + Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(0))), true), + Arguments.argumentSet("Presence updated exactly at the second before which it would have been considered offline", + Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(300))), true), + Arguments.argumentSet("Presence expired", + Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(400))), false) + ); + } + + private static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral) { return MessageProtos.Envelope.newBuilder() .setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(16))) + .setEphemeral(ephemeral) .build(); } @@ -90,12 +192,31 @@ class FoundationDbMessageStoreTest { }).join(); } - private Versionstamp getLastMessageVersionstamp(final AciServiceIdentifier aci) { + private Optional getMessagesAvailableWatch(final AciServiceIdentifier aci) { return FOUNDATION_DB_EXTENSION.getDatabase() - .read(transaction -> transaction.get(foundationDbMessageStore.getLastMessageKey(aci)) - .thenApply(Tuple::fromBytes) - .thenApply(t -> t.getVersionstamp(0))) + .read(transaction -> transaction.get(foundationDbMessageStore.getMessagesAvailableWatchKey(aci)) + .thenApply(value -> value == null ? null : Tuple.fromBytes(value).getVersionstamp(0)) + .thenApply(Optional::ofNullable)) .join(); } + private void writePresenceKey(final AciServiceIdentifier aci, final byte deviceId, final int serverId, + final long secondsBeforeCurrentTime) { + FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> { + final byte[] presenceKey = foundationDbMessageStore.getPresenceKey(aci, deviceId); + final long presenceUpdateEpochSeconds = getEpochSecondsBeforeClock(secondsBeforeCurrentTime); + final long presenceValue = constructPresenceValue(serverId, presenceUpdateEpochSeconds); + transaction.set(presenceKey, Conversions.longToByteArray(presenceValue)); + return null; + }); + } + + private static long getEpochSecondsBeforeClock(final long secondsBefore) { + return CLOCK.instant().minusSeconds(secondsBefore).getEpochSecond(); + } + + private static long constructPresenceValue(final int serverId, final long presenceUpdateEpochSeconds) { + return (long) (serverId & 0x0ffff) << 48 | (presenceUpdateEpochSeconds & 0x0000ffffffffffffL); + } + }