Check presence before updating last message versionstamp

This commit is contained in:
ameya-signal
2025-08-18 10:16:00 -07:00
committed by GitHub
parent 4acb3b5ac7
commit a1d9c4c062
2 changed files with 231 additions and 50 deletions

View File

@@ -1,7 +1,9 @@
package org.whispersystems.textsecuregcm.storage.foundationdb;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.apple.foundationdb.Database;
import com.apple.foundationdb.tuple.Tuple;
@@ -9,21 +11,32 @@ import com.apple.foundationdb.tuple.Versionstamp;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.UncheckedIOException;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.FoundationDbExtension;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
@@ -34,50 +47,139 @@ class FoundationDbMessageStoreTest {
private FoundationDbMessageStore foundationDbMessageStore;
private static final Clock CLOCK = Clock.fixed(Instant.ofEpochSecond(500), ZoneId.of("UTC"));
@BeforeEach
void setup() {
foundationDbMessageStore = new FoundationDbMessageStore(
new Database[]{FOUNDATION_DB_EXTENSION.getDatabase()},
Executors.newVirtualThreadPerTaskExecutor());
Executors.newVirtualThreadPerTaskExecutor(),
CLOCK);
}
@Test
void insert() {
@ParameterizedTest
@MethodSource
void insert(final long presenceUpdatedBeforeSeconds, final boolean ephemeral, final boolean expectMessagesInserted,
final boolean expectVersionstampUpdated) {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
final List<Byte> deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6)
.mapToObj(i -> (byte) i)
.toList();
deviceIds.forEach(deviceId -> writePresenceKey(aci, deviceId, 1, presenceUpdatedBeforeSeconds));
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = deviceIds.stream()
.collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage()));
final Versionstamp versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
.collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral)));
final Optional<Versionstamp> versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
assertNotNull(versionstamp);
final Map<Byte, MessageProtos.Envelope> storedMessagesByDeviceId = deviceIds.stream()
.collect(Collectors.toMap(Function.identity(), deviceId -> {
try {
return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp));
} catch (final InvalidProtocolBufferException e) {
throw new UncheckedIOException(e);
}
}));
if (expectMessagesInserted) {
assertTrue(versionstamp.isPresent());
final Map<Byte, MessageProtos.Envelope> storedMessagesByDeviceId = deviceIds.stream()
.collect(Collectors.toMap(Function.identity(), deviceId -> {
try {
return MessageProtos.Envelope.parseFrom(getMessageByVersionstamp(aci, deviceId, versionstamp.get()));
} catch (final InvalidProtocolBufferException e) {
throw new UncheckedIOException(e);
}
}));
assertEquals(messagesByDeviceId, storedMessagesByDeviceId);
assertEquals(versionstamp, getLastMessageVersionstamp(aci),
"last message versionstamp should be the versionstamp of the last insert transaction");
assertEquals(messagesByDeviceId, storedMessagesByDeviceId);
} else {
assertTrue(versionstamp.isEmpty());
}
if (expectVersionstampUpdated) {
assertEquals(versionstamp, getMessagesAvailableWatch(aci),
"messages available versionstamp should be the versionstamp of the last insert transaction");
} else {
assertTrue(getMessagesAvailableWatch(aci).isEmpty());
}
}
private static Stream<Arguments> insert() {
return Stream.of(
Arguments.argumentSet("Non-ephemeral messages with all devices online",
10L, false, true, true),
Arguments.argumentSet(
"Ephemeral messages with presence updated exactly at the second before which the device would be considered offline",
300L, true, true, true),
Arguments.argumentSet("Non-ephemeral messages for with all devices offline",
310L, false, true, false),
Arguments.argumentSet("Ephemeral messages with all devices offline",
310L, true, false, false)
);
}
@Test
void versionstampCorrectlyUpdatedOnMultipleInserts() {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage())).join();
final Versionstamp secondMessageVersionstamp = foundationDbMessageStore.insert(aci,
Map.of(Device.PRIMARY_ID, generateRandomMessage())).join();
assertEquals(secondMessageVersionstamp, getLastMessageVersionstamp(aci));
writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L);
foundationDbMessageStore.insert(aci, Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join();
final Optional<Versionstamp> secondMessageVersionstamp = foundationDbMessageStore.insert(aci,
Map.of(Device.PRIMARY_ID, generateRandomMessage(false))).join();
assertEquals(secondMessageVersionstamp, getMessagesAvailableWatch(aci));
}
private static MessageProtos.Envelope generateRandomMessage() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void insertOnlyOneDevicePresent(final boolean ephemeral) {
final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID());
final List<Byte> deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6)
.mapToObj(i -> (byte) i)
.toList();
// Only 1 device has a recent presence, the others do not have presence keys present.
writePresenceKey(aci, Device.PRIMARY_ID, 1, 10L);
final Map<Byte, MessageProtos.Envelope> messagesByDeviceId = deviceIds.stream()
.collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral)));
final Optional<Versionstamp> versionstamp = foundationDbMessageStore.insert(aci, messagesByDeviceId).join();
assertNotNull(versionstamp);
assertTrue(versionstamp.isPresent(),
"versionstamp should be present since at least one message should be inserted");
assertArrayEquals(
messagesByDeviceId.get(Device.PRIMARY_ID).toByteArray(),
getMessageByVersionstamp(aci, Device.PRIMARY_ID, versionstamp.get()),
"Message for primary should always be stored since it has a recently updated presence");
if (ephemeral) {
assertTrue(IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID + 6)
.mapToObj(deviceId -> getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get()))
.allMatch(Objects::isNull), "Ephemeral messages for non-present devices must not be stored");
} else {
IntStream.range(Device.PRIMARY_ID + 1, Device.PRIMARY_ID)
.forEach(deviceId -> {
final byte[] messageBytes = getMessageByVersionstamp(aci, (byte) deviceId, versionstamp.get());
assertEquals(messagesByDeviceId.get((byte) deviceId).toByteArray(), messageBytes,
"Non-ephemeral messages must always be stored");
});
}
}
@ParameterizedTest
@MethodSource
void isClientPresent(final byte[] presenceValueBytes, final boolean expectPresent) {
assertEquals(expectPresent, foundationDbMessageStore.isClientPresent(presenceValueBytes));
}
static Stream<Arguments> isClientPresent() {
return Stream.of(
Arguments.argumentSet("Presence value doesn't exist",
null, false),
Arguments.argumentSet("Presence updated recently",
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(5))), true),
Arguments.argumentSet("Presence updated same second as current time",
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(0))), true),
Arguments.argumentSet("Presence updated exactly at the second before which it would have been considered offline",
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(300))), true),
Arguments.argumentSet("Presence expired",
Conversions.longToByteArray(constructPresenceValue(42, getEpochSecondsBeforeClock(400))), false)
);
}
private static MessageProtos.Envelope generateRandomMessage(final boolean ephemeral) {
return MessageProtos.Envelope.newBuilder()
.setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(16)))
.setEphemeral(ephemeral)
.build();
}
@@ -90,12 +192,31 @@ class FoundationDbMessageStoreTest {
}).join();
}
private Versionstamp getLastMessageVersionstamp(final AciServiceIdentifier aci) {
private Optional<Versionstamp> getMessagesAvailableWatch(final AciServiceIdentifier aci) {
return FOUNDATION_DB_EXTENSION.getDatabase()
.read(transaction -> transaction.get(foundationDbMessageStore.getLastMessageKey(aci))
.thenApply(Tuple::fromBytes)
.thenApply(t -> t.getVersionstamp(0)))
.read(transaction -> transaction.get(foundationDbMessageStore.getMessagesAvailableWatchKey(aci))
.thenApply(value -> value == null ? null : Tuple.fromBytes(value).getVersionstamp(0))
.thenApply(Optional::ofNullable))
.join();
}
private void writePresenceKey(final AciServiceIdentifier aci, final byte deviceId, final int serverId,
final long secondsBeforeCurrentTime) {
FOUNDATION_DB_EXTENSION.getDatabase().run(transaction -> {
final byte[] presenceKey = foundationDbMessageStore.getPresenceKey(aci, deviceId);
final long presenceUpdateEpochSeconds = getEpochSecondsBeforeClock(secondsBeforeCurrentTime);
final long presenceValue = constructPresenceValue(serverId, presenceUpdateEpochSeconds);
transaction.set(presenceKey, Conversions.longToByteArray(presenceValue));
return null;
});
}
private static long getEpochSecondsBeforeClock(final long secondsBefore) {
return CLOCK.instant().minusSeconds(secondsBefore).getEpochSecond();
}
private static long constructPresenceValue(final int serverId, final long presenceUpdateEpochSeconds) {
return (long) (serverId & 0x0ffff) << 48 | (presenceUpdateEpochSeconds & 0x0000ffffffffffffL);
}
}