Fetch messages to persist via the same pathway as messages to deliver

This commit is contained in:
Jon Chambers
2026-03-13 09:46:14 -04:00
committed by Jon Chambers
parent 4578150e5a
commit dc8e03bd40
10 changed files with 235 additions and 222 deletions

View File

@@ -20,7 +20,6 @@ import io.micrometer.core.instrument.Timer;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
@@ -88,8 +87,6 @@ public class MessagePersister implements Managed {
private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis();
private static final int CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT = 3;
private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class);
public MessagePersister(final MessagesCache messagesCache,
@@ -240,11 +237,8 @@ public class MessagePersister implements Managed {
final Account account = accountAndDevice.getT1();
final Device device = accountAndDevice.getT2();
return Mono.fromCallable(() -> {
persistQueue(account, device, tags);
return 1;
})
.subscribeOn(persistQueueScheduler)
return persistQueue(account, device, tags)
.thenReturn(1)
.retryWhen(retryBackoffSpec
// Don't retry with backoff for persistence exceptions
.filter(e -> !(e instanceof MessagePersistenceException)))
@@ -270,7 +264,7 @@ public class MessagePersister implements Managed {
}
@VisibleForTesting
void persistQueue(final Account account, final Device device, final Tags baseTags) throws MessagePersistenceException {
Mono<Void> persistQueue(final Account account, final Device device, final Tags baseTags) {
final UUID accountUuid = account.getUuid();
final byte deviceId = device.getId();
@@ -279,122 +273,114 @@ public class MessagePersister implements Managed {
.orElse("unknown"));
final Timer.Sample sample = Timer.start();
final Tags tags = baseTags.and(platformTag);
messagesCache.lockQueueForPersistence(accountUuid, deviceId);
return Flux.usingWhen(
messagesCache.lockQueueForPersistence(accountUuid, deviceId)
.thenReturn(true),
_ -> Flux.from(messagesCache.getMessagesToPersist(accountUuid, deviceId))
.buffer(MESSAGE_BATCH_LIMIT)
.flatMap(messages -> {
final int urgentMessageCount = (int) messages.stream().filter(MessageProtos.Envelope::getUrgent).count();
final int nonUrgentMessageCount = messages.size() - urgentMessageCount;
try {
int messageCount = 0;
List<MessageProtos.Envelope> messages;
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "true")).increment(urgentMessageCount);
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "false")).increment(nonUrgentMessageCount);
Metrics.counter(PERSISTED_BYTES_COUNTER_NAME, tags)
.increment(messages.stream().mapToInt(MessageProtos.Envelope::getSerializedSize).sum());
int consecutiveEmptyCacheRemovals = 0;
return Mono.fromRunnable(() -> messagesManager.persistMessages(accountUuid, device, messages))
.subscribeOn(persistQueueScheduler)
.thenReturn(messages.size());
}, 1)
.reduce(0, Integer::sum)
.onErrorResume(ItemCollectionSizeLimitExceededException.class, _ -> {
final boolean isPrimary = deviceId == Device.PRIMARY_ID;
Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment();
// may throw, in which case we'll retry later by the usual mechanism
if (isPrimary) {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will trim oldest messages",
account.getUuid(), deviceId);
do {
messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT);
return trimQueue(account, device)
.then(Mono.error(new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed")));
} else {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", accountUuid, deviceId);
final int urgentMessageCount = (int) messages.stream().filter(MessageProtos.Envelope::getUrgent).count();
final int nonUrgentMessageCount = messages.size() - urgentMessageCount;
final Tags tags = baseTags.and(platformTag);
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "true")).increment(urgentMessageCount);
Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "false")).increment(nonUrgentMessageCount);
Metrics.counter(PERSISTED_BYTES_COUNTER_NAME, tags)
.increment(messages.stream().mapToInt(MessageProtos.Envelope::getSerializedSize).sum());
int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages);
messageCount += messages.size();
if (messagesRemovedFromCache == 0) {
consecutiveEmptyCacheRemovals += 1;
} else {
consecutiveEmptyCacheRemovals = 0;
}
if (consecutiveEmptyCacheRemovals > CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT) {
throw new MessagePersistenceException("persistence failure loop detected");
}
} while (!messages.isEmpty());
DistributionSummary.builder(QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME)
.tags(Tags.of(platformTag))
.register(Metrics.globalRegistry)
.record(messageCount);
} catch (final ItemCollectionSizeLimitExceededException e) {
final boolean isPrimary = deviceId == Device.PRIMARY_ID;
Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment();
// may throw, in which case we'll retry later by the usual mechanism
if (isPrimary) {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will trim oldest messages",
account.getUuid(), deviceId);
trimQueue(account, deviceId);
throw new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed");
} else {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", accountUuid, deviceId);
accountsManager.removeDevice(account, deviceId);
}
} finally {
messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
sample.stop(PERSIST_QUEUE_TIMER);
}
return Mono.fromRunnable(() -> accountsManager.removeDevice(account, deviceId))
.subscribeOn(persistQueueScheduler)
.then(Mono.empty());
}
})
.doOnSuccess(messagesPersisted -> {
if (messagesPersisted != null) {
DistributionSummary.builder(QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME)
.tags(Tags.of(platformTag))
.register(Metrics.globalRegistry)
.record(messagesPersisted);
}
})
.doOnTerminate(() -> sample.stop(PERSIST_QUEUE_TIMER)),
_ -> messagesCache.unlockQueueForPersistence(accountUuid, deviceId))
.then();
}
private void trimQueue(final Account account, byte deviceId) {
private Mono<Void> trimQueue(final Account account, final Device device) {
final UUID aci = account.getIdentifier(IdentityType.ACI);
final byte deviceId = device.getId();
final Optional<Device> maybeDevice = account.getDevice(deviceId);
if (maybeDevice.isEmpty()) {
logger.warn("Not deleting messages for overfull queue {}::{}, deviceId {} does not exist",
aci, deviceId, deviceId);
return;
}
final Device device = maybeDevice.get();
// Calculate how many bytes we should trim
final long cachedMessageBytes = messagesCache.estimatePersistedQueueSizeBytes(aci, deviceId).join();
final double extraRoomRatio = this.dynamicConfigurationManager.getConfiguration()
.getMessagePersisterConfiguration()
.getTrimOversizedQueueExtraRoomRatio();
final long targetDeleteBytes = Math.round(cachedMessageBytes * extraRoomRatio);
final AtomicLong oldestMessage = new AtomicLong(0L);
final AtomicLong newestMessage = new AtomicLong(0L);
final AtomicLong bytesDeleted = new AtomicLong(0L);
// Iterate from the oldest message until we've removed targetDeleteBytes
final Pair<Long, Long> outcomes = Flux.from(messagesManager.getMessagesForDeviceReactive(aci, device, false))
.concatMap(envelope -> {
if (bytesDeleted.getAndAdd(envelope.getSerializedSize()) >= targetDeleteBytes) {
return Mono.just(Optional.<MessageProtos.Envelope>empty());
}
oldestMessage.compareAndSet(0L, envelope.getServerTimestamp());
newestMessage.set(envelope.getServerTimestamp());
return Mono.just(Optional.of(envelope));
})
.takeWhile(Optional::isPresent)
.flatMap(maybeEnvelope -> {
// We know this must be present because `takeWhile` only passes values that are present
final MessageProtos.Envelope envelope = maybeEnvelope.orElseThrow(AssertionError::new);
TRIMMED_MESSAGE_COUNTER.increment();
TRIMMED_MESSAGE_BYTES_COUNTER.increment(envelope.getSerializedSize());
return Mono
.fromCompletionStage(() -> messagesManager
.delete(aci, device, UUID.fromString(envelope.getServerGuid()), envelope.getServerTimestamp()))
.retryWhen(retryBackoffSpec)
.map(Optional::isPresent);
})
.reduce(Pair.of(0L, 0L), (acc, deleted) -> deleted
? Pair.of(acc.getLeft() + 1, acc.getRight())
: Pair.of(acc.getLeft(), acc.getRight() + 1))
.blockOptional()
.orElseGet(() -> Pair.of(0L, 0L));
final AtomicLong cachedMessageBytes = new AtomicLong(0L);
final AtomicLong targetDeleteBytes = new AtomicLong(0L);
logger.warn(
"Finished trimming {}:{}. Oldest message = {}, newest message = {}. Attempted to delete {} persisted bytes to make room for {} cached message bytes. Delete outcomes: {} present, {} missing.",
aci, deviceId,
Instant.ofEpochMilli(oldestMessage.get()), Instant.ofEpochMilli(newestMessage.get()),
targetDeleteBytes, cachedMessageBytes,
outcomes.getLeft(), outcomes.getRight());
return Mono.fromFuture(() -> messagesCache.estimatePersistedQueueSizeBytes(aci, deviceId))
.flatMap(estimatedPersistedQueueSize -> {
cachedMessageBytes.set(estimatedPersistedQueueSize);
targetDeleteBytes.set(Math.round(estimatedPersistedQueueSize * extraRoomRatio));
return Flux.from(messagesManager.getMessagesForDeviceReactive(aci, device, false))
.concatMap(envelope -> {
if (bytesDeleted.getAndAdd(envelope.getSerializedSize()) >= targetDeleteBytes.get()) {
return Mono.just(Optional.<MessageProtos.Envelope>empty());
}
oldestMessage.compareAndSet(0L, envelope.getServerTimestamp());
newestMessage.set(envelope.getServerTimestamp());
return Mono.just(Optional.of(envelope));
})
.takeWhile(Optional::isPresent)
.flatMap(maybeEnvelope -> {
// We know this must be present because `takeWhile` only passes values that are present
final MessageProtos.Envelope envelope = maybeEnvelope.orElseThrow(AssertionError::new);
TRIMMED_MESSAGE_COUNTER.increment();
TRIMMED_MESSAGE_BYTES_COUNTER.increment(envelope.getSerializedSize());
return Mono
.fromCompletionStage(() -> messagesManager
.delete(aci, device, UUID.fromString(envelope.getServerGuid()), envelope.getServerTimestamp()))
.retryWhen(retryBackoffSpec)
.map(Optional::isPresent);
})
.reduce(Pair.of(0L, 0L), (acc, deleted) -> deleted
? Pair.of(acc.getLeft() + 1, acc.getRight())
: Pair.of(acc.getLeft(), acc.getRight() + 1));
})
.doOnSuccess(outcomes -> {
if (outcomes != null) {
logger.warn(
"Finished trimming {}:{}. Oldest message = {}, newest message = {}. Attempted to delete {} persisted bytes to make room for {} cached message bytes. Delete outcomes: {} present, {} missing.",
aci, deviceId,
Instant.ofEpochMilli(oldestMessage.get()), Instant.ofEpochMilli(newestMessage.get()),
targetDeleteBytes, cachedMessageBytes,
outcomes.getLeft(), outcomes.getRight());
}
})
.then();
}
@VisibleForTesting

View File

@@ -28,6 +28,7 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -139,7 +140,6 @@ public class MessagesCache {
private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert"));
private final Timer insertSharedMrmPayloadTimer = Metrics.timer(name(MessagesCache.class, "insertSharedMrmPayload"));
private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get"));
private final Timer removeByGuidTimer = Metrics.timer(name(MessagesCache.class, "removeByGuid"));
private final Timer removeRecipientViewTimer = Metrics.timer(name(MessagesCache.class, "removeRecipientView"));
private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear"));
@@ -314,13 +314,30 @@ public class MessagesCache {
.toCompletableFuture();
}
public Publisher<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDevice) {
/**
 * Returns the messages queued for the given device via the regular delivery pathway.
 * <p>
 * Delegates to the four-argument overload with an ephemeral-message cutoff of "now minus
 * {@code MAX_EPHEMERAL_MESSAGE_DELAY}" and without bypassing the persistence lock.
 *
 * @param destinationUuid identifier of the destination account
 * @param destinationDeviceId ID of the destination device within that account
 * @return a publisher of the envelopes produced by the delegate
 */
public Publisher<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId) {
  final Instant earliestAllowableEphemeralTimestamp = clock.instant().minus(MAX_EPHEMERAL_MESSAGE_DELAY);
  // Regular readers must respect the persistence lock
  final boolean bypassLock = false;

  return get(destinationUuid, destinationDeviceId, earliestAllowableEphemeralTimestamp, bypassLock);
}
final long earliestAllowableEphemeralTimestamp =
clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();
/**
 * Returns the messages in the given device's queue that should be persisted.
 * <p>
 * Uses the same retrieval pathway as {@code get(UUID, byte)}, but with the ephemeral-message cutoff pushed to
 * the largest representable millisecond timestamp and with the persistence lock bypassed.
 *
 * @param accountUuid identifier of the account that owns the queue
 * @param destinationDeviceId ID of the device within that account
 * @return a publisher of the envelopes to persist
 */
Publisher<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDeviceId) {
  // Discard all ephemeral messages when persisting
  final Instant earliestAllowableEphemeralTimestamp = Instant.ofEpochMilli(Long.MAX_VALUE);

  final Publisher<MessageProtos.Envelope> messages =
      get(accountUuid, destinationDeviceId, earliestAllowableEphemeralTimestamp, true);

  return Flux.from(messages);
}
final Flux<MessageProtos.Envelope> allMessages = getAllMessages(destinationUuid, destinationDevice,
earliestAllowableEphemeralTimestamp, PAGE_SIZE)
private Publisher<MessageProtos.Envelope> get(final UUID destinationUuid,
final byte destinationDeviceId,
final Instant earliestAllowableEphemeralTimestamp,
final boolean bypassLock) {
final long earliestAllowableEphemeralTimestampMillis = earliestAllowableEphemeralTimestamp.toEpochMilli();
final Flux<MessageProtos.Envelope> allMessages = getAllMessages(destinationUuid, destinationDeviceId,
earliestAllowableEphemeralTimestampMillis, PAGE_SIZE, bypassLock)
.publish()
// We expect exactly three subscribers to this base flux:
// 1. the websocket that delivers messages to clients
@@ -332,23 +349,23 @@ public class MessagesCache {
final Flux<MessageProtos.Envelope> messagesToPublish = allMessages
.filter(Predicate.not(envelope ->
isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp) || isStaleMrmMessage(envelope)));
isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestampMillis) || isStaleMrmMessage(envelope)));
final Flux<MessageProtos.Envelope> staleEphemeralMessages = allMessages
.filter(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp));
discardStaleMessages(destinationUuid, destinationDevice, staleEphemeralMessages, staleEphemeralMessagesCounter, "ephemeral");
.filter(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestampMillis));
discardStaleMessages(destinationUuid, destinationDeviceId, staleEphemeralMessages, staleEphemeralMessagesCounter, "ephemeral");
final Flux<MessageProtos.Envelope> staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage)
// clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data
.map(envelope -> envelope.toBuilder().clearSharedMrmKey().build());
discardStaleMessages(destinationUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm");
discardStaleMessages(destinationUuid, destinationDeviceId, staleMrmMessages, staleMrmMessagesCounter, "mrm");
return messagesToPublish.name(GET_FLUX_NAME)
.tap(Micrometer.metrics(Metrics.globalRegistry));
}
public Mono<Long> getEarliestUndeliveredTimestamp(final UUID destinationUuid, final byte destinationDevice) {
return getAllMessages(destinationUuid, destinationDevice, -1, 1)
return getAllMessages(destinationUuid, destinationDevice, -1, 1, true)
.next()
.map(MessageProtos.Envelope::getServerTimestamp);
}
@@ -380,18 +397,21 @@ public class MessagesCache {
}
@VisibleForTesting
Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final byte destinationDevice,
final long earliestAllowableEphemeralTimestamp, final int pageSize) {
Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid,
final byte destinationDevice,
final long earliestAllowableEphemeralTimestamp,
final int pageSize,
final boolean bypassLock) {
// fetch messages by page
return getNextMessagePage(destinationUuid, destinationDevice, -1, pageSize)
return getNextMessagePage(destinationUuid, destinationDevice, -1, pageSize, bypassLock)
.expand(queueItemsAndLastMessageId -> {
// expand() is breadth-first, so each page will be published in order
if (queueItemsAndLastMessageId.first().isEmpty()) {
return Mono.empty();
}
return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second(), pageSize);
return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second(), pageSize, bypassLock);
})
.limitRate(1)
// we want to ensure we don't accidentally block the Lettuce/netty i/o executors
@@ -536,10 +556,13 @@ public class MessagesCache {
.subscribe();
}
private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice,
long messageId, int pageSize) {
private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid,
final byte destinationDevice,
final long messageId,
final int pageSize,
final boolean bypassLock) {
return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId)
return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId, bypassLock)
.map(queueItems -> {
logger.trace("Processing page: {}", messageId);
@@ -590,35 +613,6 @@ public class MessagesCache {
.toFuture();
}
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice,
final int limit) {
final Timer.Sample sample = Timer.start();
final List<byte[]> messages = redisCluster.withBinaryCluster(connection ->
connection.sync().zrange(getMessageQueueKey(accountUuid, destinationDevice), 0, limit));
final Flux<MessageProtos.Envelope> allMessages = parseAndFetchMrms(Flux.fromIterable(messages), destinationDevice);
final Flux<MessageProtos.Envelope> messagesToPersist = allMessages
.filter(Predicate.not(envelope ->
envelope.getEphemeral() || isStaleMrmMessage(envelope)));
final Flux<MessageProtos.Envelope> ephemeralMessages = allMessages
.filter(MessageProtos.Envelope::getEphemeral);
discardStaleMessages(accountUuid, destinationDevice, ephemeralMessages, staleEphemeralMessagesCounter, "ephemeral");
final Flux<MessageProtos.Envelope> staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage)
// clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data
.map(envelope -> envelope.toBuilder().clearSharedMrmKey().build());
discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm");
return messagesToPersist
.collectList()
.doOnTerminate(() -> sample.stop(getMessagesTimer))
.block(Duration.ofSeconds(5));
}
private Flux<MessageProtos.Envelope> parseAndFetchMrms(final Flux<byte[]> serializedMessages, final byte destinationDevice) {
return serializedMessages
.mapNotNull(message -> {
@@ -720,13 +714,14 @@ public class MessagesCache {
ScanArgs.Builder.matches("user_queue::*").limit(scanCount)));
}
void lockQueueForPersistence(final UUID accountUuid, final byte deviceId) {
redisCluster.useBinaryCluster(
connection -> connection.sync().setex(getPersistInProgressKey(accountUuid, deviceId), 30, LOCK_VALUE));
Mono<Void> lockQueueForPersistence(final UUID accountUuid, final byte deviceId) {
return redisCluster.withBinaryCluster(
connection -> connection.reactive().setex(getPersistInProgressKey(accountUuid, deviceId), 30, LOCK_VALUE))
.then();
}
void unlockQueueForPersistence(final UUID accountUuid, final byte deviceId) {
unlockQueueScript.execute(accountUuid, deviceId);
Mono<Void> unlockQueueForPersistence(final UUID accountUuid, final byte deviceId) {
return unlockQueueScript.execute(accountUuid, deviceId);
}
static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) {

View File

@@ -28,15 +28,20 @@ class MessagesCacheGetItemsScript {
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.OBJECT);
}
Mono<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
int limit, long afterMessageId) {
Mono<List<byte[]>> execute(final UUID destinationUuid,
final byte destinationDevice,
final int limit,
final long afterMessageId,
final boolean bypassLock) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getPersistInProgressKey(destinationUuid, destinationDevice) // queueLockKey
);
final List<byte[]> args = List.of(
String.valueOf(limit).getBytes(StandardCharsets.UTF_8), // limit
String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8) // afterMessageId
String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8), // afterMessageId
String.valueOf(bypassLock).getBytes(StandardCharsets.UTF_8) // bypassLock
);
//noinspection unchecked
return getItemsScript.executeBinaryReactive(keys, args)

View File

@@ -6,15 +6,17 @@
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.push.ClientEvent;
import org.whispersystems.textsecuregcm.push.MessagesPersistedEvent;
import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.util.ResilienceUtil;
import java.io.IOException;
import java.util.List;
import java.util.UUID;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;
/**
* Unlocks a message queue for persistence/message retrieval.
@@ -33,13 +35,14 @@ class MessagesCacheUnlockQueueScript {
ClusterLuaScript.fromResource(redisCluster, "lua/unlock_queue.lua", ScriptOutputType.STATUS);
}
void execute(final UUID accountIdentifier, final byte deviceId) {
Mono<Void> execute(final UUID accountIdentifier, final byte deviceId) {
final List<byte[]> keys = List.of(
MessagesCache.getPersistInProgressKey(accountIdentifier, deviceId), // persistInProgressKey
RedisMessageAvailabilityManager.getClientEventChannel(accountIdentifier, deviceId) // eventChannelKey
);
ResilienceUtil.getGeneralRedisRetry(MessagesCache.RETRY_NAME)
.executeRunnable(() -> unlockQueueScript.executeBinary(keys, MESSAGES_PERSISTED_EVENT_ARGS));
return unlockQueueScript.executeBinaryReactive(keys, MESSAGES_PERSISTED_EVENT_ARGS)
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
.then();
}
}

View File

@@ -2,14 +2,17 @@
-- returns a list of all envelopes and their queue-local IDs
local queueKey = KEYS[1] -- sorted set of all Envelopes for a device, scored by queue-local ID
local queueLockKey = KEYS[2] -- a key whose presence indicates that the queue is being persistent and must not be read
local queueLockKey = KEYS[2] -- a key whose presence indicates that the queue is being persisted and must not be read
local limit = ARGV[1] -- [number] the maximum number of messages to return
local afterMessageId = ARGV[2] -- [number] a queue-local ID to exclusively start after, to support pagination. Use -1 to start at the beginning
local bypassLock = ARGV[3] -- [string] whether to bypass the persistence lock (i.e. when fetching messages for persistence)
local locked = redis.call("GET", queueLockKey)
if bypassLock ~= "true" then
local locked = redis.call("GET", queueLockKey)
if locked then
return {}
if locked then
return {}
end
end
if afterMessageId == "null" or afterMessageId == nil then

View File

@@ -218,7 +218,7 @@ class MessagePersisterIntegrationTest {
}
@Test
void testPersistFirstPageDiscarded() throws MessagePersistenceException {
void testPersistFirstPageDiscarded() {
final int discardableMessages = MessagePersister.MESSAGE_BATCH_LIMIT * 2;
final int persistableMessages = MessagePersister.MESSAGE_BATCH_LIMIT + 1;
@@ -245,7 +245,7 @@ class MessagePersisterIntegrationTest {
expectedMessages.add(message);
}
messagePersister.persistQueue(account, account.getDevice(Device.PRIMARY_ID).orElseThrow(), Tags.empty());
messagePersister.persistQueue(account, account.getDevice(Device.PRIMARY_ID).orElseThrow(), Tags.empty()).block();
final DynamoDbClient dynamoDB = DYNAMO_DB_EXTENSION.getDynamoDbClient();

View File

@@ -64,6 +64,7 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import reactor.util.retry.Retry;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
@@ -176,6 +177,25 @@ class MessagePersisterTest {
resubscribeRetryExecutorService.awaitTermination(1, TimeUnit.SECONDS);
}
@Test
void persistQueue() {
  // Three full batches plus a partial one, so the queue cannot be persisted in a single batch
  final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;

  insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount,
      CLOCK.instant().minus(PERSIST_DELAY.plusSeconds(1)));

  final Device device = destinationAccount.getDevice(DESTINATION_DEVICE_ID).orElseThrow();
  messagePersister.persistQueue(destinationAccount, device, Tags.empty()).block();

  @SuppressWarnings("unchecked") final ArgumentCaptor<List<MessageProtos.Envelope>> storedBatches =
      ArgumentCaptor.forClass(List.class);

  verify(messagesDynamoDb, atLeastOnce())
      .store(storedBatches.capture(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE));

  // Every inserted message must show up in exactly one of the stored batches
  final int storedMessageCount = storedBatches.getAllValues().stream()
      .mapToInt(List::size)
      .sum();

  assertEquals(messageCount, storedMessageCount);
}
@Test
void persistNextNodeNoQueues() {
assertEquals(0, messagePersister.persistNextNode());
@@ -407,39 +427,6 @@ class MessagePersisterTest {
assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
void persistNodePersistQueueMessagePersistenceException() {
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount,
CLOCK.instant().minus(PERSIST_DELAY.plusSeconds(1)));
// Provoke a MessagePersistenceException
when(messagesManager.persistMessages(any(), any(), any())).thenReturn(0);
assertEquals(0, messagePersister.persistNode(
getNodeWithKey(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID))));
// We use this as a proxy for attempts to persist messages; for a MessagePersistenceException, we should NOT retry,
// and this should happen exactly once
verify(messagesCache).lockQueueForPersistence(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID);
verify(messagesDynamoDb, never()).store(any(), any(), any());
}
@Test
void testPersistQueueRetryLoop() {
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, CLOCK.instant().minus(PERSIST_DELAY.plusSeconds(1)));
// returning `0` indicates something not working correctly
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenReturn(0);
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
assertThrows(MessagePersistenceException.class,
() -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty())));
}
@Test
void testUnlinkOnFullQueue() {
final int messageCount = 1;
@@ -469,12 +456,15 @@ class MessagePersisterTest {
final Device destination = mock(Device.class);
when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID);
when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination));
when(destinationAccount.getDevices())
.thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination));
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
when(messagesManager.persistMessages(any(), any(), any()))
.thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()));
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block());
verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID);
}
@@ -520,8 +510,9 @@ class MessagePersisterTest {
when(messagesManager.delete(any(), any(), any(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
assertThrows(MessagePersistenceException.class, () ->
messagePersister.persistQueue(destinationAccount, primary, Tags.empty()));
StepVerifier.create(messagePersister.persistQueue(destinationAccount, primary, Tags.empty()))
.expectError(MessagePersistenceException.class)
.verify();
verify(messagesManager, times(expectedClearedGuids.size()))
.delete(eq(DESTINATION_ACCOUNT_UUID), eq(primary), argThat(expectedClearedGuids::contains), anyLong());
@@ -563,7 +554,7 @@ class MessagePersisterTest {
when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenThrow(new RuntimeException());
assertThrows(RuntimeException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()));
assertThrows(RuntimeException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block());
}
private static RedisClusterNode getNodeWithKey(final byte[] key) {

View File

@@ -49,7 +49,7 @@ class MessagesCacheGetItemsScriptTest {
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1)
final List<byte[]> messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1, false)
.block(Duration.ofSeconds(1));
assertNotNull(messageAndScores);

View File

@@ -68,7 +68,7 @@ class MessagesCacheInsertScriptTest {
final MessagesCacheGetItemsScript getItemsScript =
new MessagesCacheGetItemsScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
final List<byte[]> queueItems = getItemsScript.execute(destinationUuid, deviceId, 1024, 0)
final List<byte[]> queueItems = getItemsScript.execute(destinationUuid, deviceId, 1024, 0, false)
.blockOptional()
.orElseGet(Collections::emptyList);

View File

@@ -139,7 +139,7 @@ class MessagesCacheTest {
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10)
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10, false)
.count()
.blockOptional()
.orElse(0L));
@@ -319,15 +319,41 @@ class MessagesCacheTest {
expectedMessages.add(message);
}
messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block();
assertTrue(get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount).isEmpty());
messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block();
assertEquals(expectedMessages, get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@Test
void testGetMessagesToPersistLockedForPersistence() {
  final int messageCount = 100;

  // Fill the queue, remembering the messages in insertion order
  final List<MessageProtos.Envelope> insertedMessages = new ArrayList<>(messageCount);

  for (int inserted = 0; inserted < messageCount; inserted++) {
    final UUID guid = UUID.randomUUID();
    final MessageProtos.Envelope envelope = generateRandomMessage(guid, true);

    messagesCache.insert(guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, envelope).join();
    insertedMessages.add(envelope);
  }

  messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block();

  try {
    // Even with the persistence lock held, fetching messages to persist must return the full queue
    final List<MessageProtos.Envelope> messagesToPersist =
        Flux.from(messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID))
            .collectList()
            .blockOptional()
            .orElseThrow();

    assertEquals(insertedMessages, messagesToPersist);
  } finally {
    // Always release the lock so later tests see an unlocked queue
    messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block();
  }
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testGetMessagesPublisher(final boolean expectStale) throws Exception {
@@ -385,7 +411,7 @@ class MessagesCacheTest {
.get(5, TimeUnit.SECONDS);
final List<MessageProtos.Envelope> messages = messagesCache.getAllMessages(DESTINATION_UUID,
DESTINATION_DEVICE_ID, 0, 10)
DESTINATION_DEVICE_ID, 0, 10, false)
.collectList()
.toFuture().get(5, TimeUnit.SECONDS);
@@ -684,7 +710,11 @@ class MessagesCacheTest {
.build();
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage).join();
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);
final List<MessageProtos.Envelope> messages =
Flux.from(messagesCache.getMessagesToPersist(destinationUuid, deviceId))
.collectList()
.blockOptional()
.orElseThrow();
if (!sharedMrmKeyPresent) {
assertEquals(1, messages.size());
@@ -778,7 +808,7 @@ class MessagesCacheTest {
.thenReturn(Flux.from(emptyFinalPagePublisher))
.thenReturn(Flux.empty());
final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0, 10);
final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0, 10, false);
// Why initialValue = 3?
// 1. messagesCache.getAllMessages() above produces the first call