diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java index 5b0e4a9a2..376c41563 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStore.java @@ -111,8 +111,6 @@ public class FoundationDbMessageStore { new ArrayList<>(); messagesByShardId.forEach((shardId, messagesForShard) -> { - final Database shard = databases[shardId]; - int start = 0, current = 0; int estimatedTransactionSize = 0; @@ -123,7 +121,7 @@ public class FoundationDbMessageStore { .sum(); if (estimatedTransactionSize > MAX_MESSAGE_CHUNK_SIZE) { - chunkFutures.add(insertChunk(shard, messagesForShard.subList(start, current))); + chunkFutures.add(insertChunk(shardId, messagesForShard.subList(start, current))); start = current; estimatedTransactionSize = 0; @@ -133,7 +131,7 @@ public class FoundationDbMessageStore { } assert start < messagesForShard.size(); - chunkFutures.add(insertChunk(shard, messagesForShard.subList(start, messagesForShard.size()))); + chunkFutures.add(insertChunk(shardId, messagesForShard.subList(start, messagesForShard.size()))); }); return CompletableFuture.allOf(chunkFutures.toArray(CompletableFuture[]::new)) @@ -146,7 +144,7 @@ public class FoundationDbMessageStore { } private CompletableFuture>> insertChunk( - final Database database, + final int shardId, final List>> messagesByAccountIdentifier) { final Map>> insertFuturesByAci = new HashMap<>(); @@ -158,7 +156,7 @@ public class FoundationDbMessageStore { .map(MessageProtos.Envelope::getEphemeral) .orElseThrow(() -> new IllegalStateException("One or more bundles is empty")); - return database.runAsync(transaction -> { + return databases[shardId].runAsync(transaction -> { messagesByAccountIdentifier.forEach(entry -> insertFuturesByAci.put(entry.getKey(), insert(entry.getKey(), entry.getValue(), transaction))); @@ -171,7 +169,7 @@ public class FoundationDbMessageStore { .anyMatch(isPresent -> isPresent); if (anyClientPresent || !ephemeral) { return transaction.getVersionstamp() - .thenApply(versionstampBytes -> Optional.of(Versionstamp.complete(versionstampBytes))); + .thenApply(versionstampBytes -> Optional.of(Versionstamp.complete(versionstampBytes, shardId))); } return CompletableFuture.completedFuture(Optional.empty()); }); @@ -222,7 +220,7 @@ public class FoundationDbMessageStore { if (isPresent || !message.getEphemeral()) { transaction.mutate(MutationType.SET_VERSIONSTAMPED_KEY, getDeviceQueueSubspace(aci, deviceId) - .packWithVersionstamp(Tuple.from(Versionstamp.incomplete())), message.toByteArray()); + .packWithVersionstamp(Tuple.from(Versionstamp.incomplete(hashAciToShardNumber(aci)))), message.toByteArray()); } return isPresent; @@ -241,7 +239,7 @@ public class FoundationDbMessageStore { if (anyClientPresent) { transaction.mutate(MutationType.SET_VERSIONSTAMPED_VALUE, getMessagesAvailableWatchKey(aci), - Tuple.from(Versionstamp.incomplete()).packWithVersionstamp()); + Tuple.from(Versionstamp.incomplete(hashAciToShardNumber(aci))).packWithVersionstamp()); } return presenceByDeviceId; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java index b47f62b49..d2396b05f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/foundationdb/FoundationDbMessageStoreTest.java @@ -74,23 +74,34 @@ class FoundationDbMessageStoreTest { @ParameterizedTest @MethodSource - void insert(final long presenceUpdatedBeforeSeconds, final boolean ephemeral, final boolean expectMessagesInserted, - final boolean expectVersionstampUpdated, final boolean expectPresenceState) { + void insert(final long presenceUpdatedBeforeSeconds, + final boolean ephemeral, + final boolean expectMessagesInserted, + final boolean expectVersionstampUpdated, + final boolean expectPresenceState) { + final AciServiceIdentifier aci = new AciServiceIdentifier(UUID.randomUUID()); final List deviceIds = IntStream.range(Device.PRIMARY_ID, Device.PRIMARY_ID + 6) .mapToObj(i -> (byte) i) .toList(); + deviceIds.forEach(deviceId -> writePresenceKey(aci, deviceId, 1, presenceUpdatedBeforeSeconds)); + final Map messagesByDeviceId = deviceIds.stream() .collect(Collectors.toMap(Function.identity(), _ -> generateRandomMessage(ephemeral))); - final Map result = foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); - assertNotNull(result); + + final Map result = + foundationDbMessageStore.insert(aci, messagesByDeviceId).join(); + + assertTrue(result.keySet().containsAll(deviceIds)); final Optional returnedVersionstamp = result.values().stream().findFirst() .flatMap(FoundationDbMessageStore.InsertResult::versionstamp); + if (expectMessagesInserted) { assertTrue(returnedVersionstamp.isPresent()); assertTrue(result.values().stream().allMatch(insertResult -> returnedVersionstamp.equals(insertResult.versionstamp()))); + final Map storedMessagesByDeviceId = deviceIds.stream() .collect(Collectors.toMap(Function.identity(), deviceId -> { try { @@ -109,7 +120,7 @@ class FoundationDbMessageStoreTest { if (expectVersionstampUpdated) { final Optional messagesAvailableWatchVersionstamp = getMessagesAvailableWatch(aci); assertTrue(messagesAvailableWatchVersionstamp.isPresent()); - assertEquals(returnedVersionstamp, messagesAvailableWatchVersionstamp, + assertEquals(messagesAvailableWatchVersionstamp, returnedVersionstamp, "messages available versionstamp should be the versionstamp of the last insert transaction"); } else { assertTrue(getMessagesAvailableWatch(aci).isEmpty());