Automatically trim primary queue when cache cannot be persisted

This commit is contained in:
Ravi Khadiwala
2025-02-28 11:11:42 -06:00
committed by ravi-signal
parent 8517eef3fe
commit 09b50383d7
6 changed files with 280 additions and 35 deletions

View File

@@ -10,12 +10,16 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNotNull;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
@@ -26,6 +30,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@@ -34,8 +39,10 @@ import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -45,9 +52,12 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagePersisterConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
@@ -75,6 +85,8 @@ class MessagePersisterTest {
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
private static final double EXTRA_ROOM_RATIO = 2.0;
@BeforeEach
void setUp() throws Exception {
@@ -84,16 +96,21 @@ class MessagePersisterTest {
messagesDynamoDb = mock(MessagesDynamoDb.class);
accountsManager = mock(AccountsManager.class);
destinationAccount = mock(Account.class);;
destinationAccount = mock(Account.class);
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(destinationAccount));
when(accountsManager.removeDevice(any(), anyByte()))
.thenAnswer(invocation -> CompletableFuture.completedFuture(invocation.getArgument(0)));
when(destinationAccount.getUuid()).thenReturn(DESTINATION_ACCOUNT_UUID);
when(destinationAccount.getIdentifier(IdentityType.ACI)).thenReturn(DESTINATION_ACCOUNT_UUID);
when(destinationAccount.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
when(destinationAccount.getDevice(DESTINATION_DEVICE_ID)).thenReturn(Optional.of(DESTINATION_DEVICE));
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfiguration.getMessagePersisterConfiguration())
.thenReturn(new DynamicMessagePersisterConfiguration(true, EXTRA_ROOM_RATIO));
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
sharedExecutorService = Executors.newSingleThreadExecutor();
resubscribeRetryExecutorService = Executors.newSingleThreadScheduledExecutor();
@@ -285,6 +302,66 @@ class MessagePersisterTest {
verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID);
}
@Test
void testTrimOnFullPrimaryQueue() {
final byte[] queueName = MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, Device.PRIMARY_ID);
final Instant now = Instant.now();
final List<MessageProtos.Envelope> cachedMessages = Stream.generate(() -> generateMessage(
DESTINATION_ACCOUNT_UUID, UUID.randomUUID(), now.getEpochSecond(), ThreadLocalRandom.current().nextInt(100)))
.limit(10)
.toList();
final long cacheSize = cachedMessages.stream().mapToLong(MessageProtos.Envelope::getSerializedSize).sum();
for (MessageProtos.Envelope envelope : cachedMessages) {
messagesCache.insert(UUID.fromString(envelope.getServerGuid()), DESTINATION_ACCOUNT_UUID, Device.PRIMARY_ID, envelope);
}
final long expectedClearedBytes = (long) (cacheSize * EXTRA_ROOM_RATIO);
final int persistedMessageCount = 100;
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
final List<UUID> expectedClearedGuids = new ArrayList<>();
long total = 0L;
for (int i = 0; i < 100; i++) {
final UUID guid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateMessage(DESTINATION_ACCOUNT_UUID, guid, now.getEpochSecond(), 13);
persistedMessages.add(envelope);
if (total < expectedClearedBytes) {
total += envelope.getSerializedSize();
expectedClearedGuids.add(guid);
}
}
setNextSlotToPersist(SlotHash.getSlot(queueName));
final Device primary = mock(Device.class);
when(primary.getId()).thenReturn((byte) 1);
when(primary.isPrimary()).thenReturn(true);
when(primary.getFetchesMessages()).thenReturn(true);
when(destinationAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primary));
when(messagesManager.persistMessages(any(UUID.class), any(), anyList()))
.thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
when(messagesManager.getMessagesForDeviceReactive(DESTINATION_ACCOUNT_UUID, primary, false))
.thenReturn(Flux.concat(
Flux.fromIterable(persistedMessages),
Flux.fromIterable(cachedMessages)));
when(messagesManager.delete(any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
assertTimeoutPreemptively(Duration.ofSeconds(10), () ->
messagePersister.persistNextQueues(Clock.systemUTC().instant()));
verify(messagesManager, times(expectedClearedGuids.size()))
.delete(eq(DESTINATION_ACCOUNT_UUID), eq(primary), argThat(expectedClearedGuids::contains), isNotNull());
verify(messagesManager, never()).delete(any(), any(), argThat(guid -> !expectedClearedGuids.contains(guid)), any());
final List<String> queuesToPersist = messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName),
Clock.systemUTC().instant(), 1);
assertEquals(queuesToPersist.size(), 1);
assertEquals(queuesToPersist.getFirst(), new String(queueName, StandardCharsets.UTF_8));
}
@Test
void testFailedUnlinkOnFullQueueThrowsForRetry() {
final String queueName = new String(
@@ -348,20 +425,23 @@ class MessagePersisterTest {
final Instant firstMessageTimestamp) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setDestinationServiceId(accountUuid.toString())
.setClientTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
final MessageProtos.Envelope envelope = generateMessage(
accountUuid, messageGuid, firstMessageTimestamp.toEpochMilli() + i, 256);
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope).join();
}
}
private MessageProtos.Envelope generateMessage(UUID accountUuid, UUID messageGuid, long messageTimestamp, int contentSize) {
return MessageProtos.Envelope.newBuilder()
.setDestinationServiceId(accountUuid.toString())
.setClientTimestamp(messageTimestamp)
.setServerTimestamp(messageTimestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(contentSize)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
}
private void setNextSlotToPersist(final int nextSlot) {
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1)));

View File

@@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -54,6 +55,7 @@ import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
@@ -488,6 +490,46 @@ class MessagesCacheTest {
}, "Shared MRM data should be deleted asynchronously");
}
@Test
void testMessagesToPersistReactive() {
final UUID destinationUuid = UUID.randomUUID();
final ServiceIdentifier serviceId = new AciServiceIdentifier(destinationUuid);
final byte deviceId = 1;
final List<MessageProtos.Envelope> expected = IntStream.range(0, 100)
.mapToObj(i -> {
if (i % 3 == 0) {
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(serviceId, deviceId);
byte[] sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
return generateRandomMessage(UUID.randomUUID(), serviceId, true)
.toBuilder()
// clear some things added by the helper
.clearContent()
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build();
} else if (i % 13 == 0) {
return generateRandomMessage(UUID.randomUUID(), serviceId, true).toBuilder().setEphemeral(true).build();
} else {
return generateRandomMessage(UUID.randomUUID(), serviceId, true);
}
})
.filter(envelope -> !envelope.getEphemeral())
.toList();
for (MessageProtos.Envelope envelope : expected) {
messagesCache.insert(UUID.fromString(envelope.getServerGuid()), destinationUuid, deviceId, envelope).join();
}
final List<MessageProtos.Envelope> actual = messagesCache
.getMessagesToPersistReactive(destinationUuid, deviceId, 7).collectList().block();
assertEquals(expected.size(), actual.size());
for (int i = 0; i < actual.size(); i++) {
assertNotNull(actual.get(i).getContent());
assertEquals(actual.get(i).getServerGuid(), expected.get(i).getServerGuid());
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) {