diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 09037264e..977e58f71 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -20,7 +20,6 @@ import io.micrometer.core.instrument.Timer; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.List; import java.util.Locale; import java.util.Optional; import java.util.UUID; @@ -88,8 +87,6 @@ public class MessagePersister implements Managed { private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis(); - private static final int CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT = 3; - private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class); public MessagePersister(final MessagesCache messagesCache, @@ -240,11 +237,8 @@ public class MessagePersister implements Managed { final Account account = accountAndDevice.getT1(); final Device device = accountAndDevice.getT2(); - return Mono.fromCallable(() -> { - persistQueue(account, device, tags); - return 1; - }) - .subscribeOn(persistQueueScheduler) + return persistQueue(account, device, tags) + .thenReturn(1) .retryWhen(retryBackoffSpec // Don't retry with backoff for persistence exceptions .filter(e -> !(e instanceof MessagePersistenceException))) @@ -270,7 +264,7 @@ public class MessagePersister implements Managed { } @VisibleForTesting - void persistQueue(final Account account, final Device device, final Tags baseTags) throws MessagePersistenceException { + Mono persistQueue(final Account account, final Device device, final Tags baseTags) { final UUID accountUuid = account.getUuid(); final byte deviceId = device.getId(); @@ -279,122 +273,114 @@ public class MessagePersister implements Managed { .orElse("unknown")); final Timer.Sample sample = Timer.start(); + final Tags tags = baseTags.and(platformTag); - messagesCache.lockQueueForPersistence(accountUuid, deviceId); + return Flux.usingWhen( + messagesCache.lockQueueForPersistence(accountUuid, deviceId) + .thenReturn(true), + _ -> Flux.from(messagesCache.getMessagesToPersist(accountUuid, deviceId)) + .buffer(MESSAGE_BATCH_LIMIT) + .flatMap(messages -> { + final int urgentMessageCount = (int) messages.stream().filter(MessageProtos.Envelope::getUrgent).count(); + final int nonUrgentMessageCount = messages.size() - urgentMessageCount; - try { - int messageCount = 0; - List messages; + Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "true")).increment(urgentMessageCount); + Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "false")).increment(nonUrgentMessageCount); + Metrics.counter(PERSISTED_BYTES_COUNTER_NAME, tags) + .increment(messages.stream().mapToInt(MessageProtos.Envelope::getSerializedSize).sum()); - int consecutiveEmptyCacheRemovals = 0; + return Mono.fromRunnable(() -> messagesManager.persistMessages(accountUuid, device, messages)) + .subscribeOn(persistQueueScheduler) + .thenReturn(messages.size()); + }, 1) + .reduce(0, Integer::sum) + .onErrorResume(ItemCollectionSizeLimitExceededException.class, _ -> { + final boolean isPrimary = deviceId == Device.PRIMARY_ID; + Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment(); + // may throw, in which case we'll retry later by the usual mechanism + if (isPrimary) { + logger.warn("Failed to persist queue {}::{} due to overfull queue; will trim oldest messages", + account.getUuid(), deviceId); - do { - messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); + return trimQueue(account, device) + .then(Mono.error(new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed"))); + } else { + logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", accountUuid, deviceId); - final int urgentMessageCount = (int) messages.stream().filter(MessageProtos.Envelope::getUrgent).count(); - final int nonUrgentMessageCount = messages.size() - urgentMessageCount; - - final Tags tags = baseTags.and(platformTag); - - Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "true")).increment(urgentMessageCount); - Metrics.counter(PERSISTED_MESSAGE_COUNTER_NAME, tags.and("urgent", "false")).increment(nonUrgentMessageCount); - Metrics.counter(PERSISTED_BYTES_COUNTER_NAME, tags) - .increment(messages.stream().mapToInt(MessageProtos.Envelope::getSerializedSize).sum()); - - int messagesRemovedFromCache = messagesManager.persistMessages(accountUuid, device, messages); - messageCount += messages.size(); - - if (messagesRemovedFromCache == 0) { - consecutiveEmptyCacheRemovals += 1; - } else { - consecutiveEmptyCacheRemovals = 0; - } - - if (consecutiveEmptyCacheRemovals > CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT) { - throw new MessagePersistenceException("persistence failure loop detected"); - } - - } while (!messages.isEmpty()); - - DistributionSummary.builder(QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME) - .tags(Tags.of(platformTag)) - .register(Metrics.globalRegistry) - .record(messageCount); - } catch (final ItemCollectionSizeLimitExceededException e) { - final boolean isPrimary = deviceId == Device.PRIMARY_ID; - Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment(); - // may throw, in which case we'll retry later by the usual mechanism - if (isPrimary) { - logger.warn("Failed to persist queue {}::{} due to overfull queue; will trim oldest messages", - account.getUuid(), deviceId); - trimQueue(account, deviceId); - throw new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed"); - } else { - logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", accountUuid, deviceId); - accountsManager.removeDevice(account, deviceId); - } - } finally { - messagesCache.unlockQueueForPersistence(accountUuid, deviceId); - sample.stop(PERSIST_QUEUE_TIMER); - } + return Mono.fromRunnable(() -> accountsManager.removeDevice(account, deviceId)) + .subscribeOn(persistQueueScheduler) + .then(Mono.empty()); + } + }) + .doOnSuccess(messagesPersisted -> { + if (messagesPersisted != null) { + DistributionSummary.builder(QUEUE_SIZE_DISTRIBUTION_SUMMARY_NAME) + .tags(Tags.of(platformTag)) + .register(Metrics.globalRegistry) + .record(messagesPersisted); + } + }) + .doOnTerminate(() -> sample.stop(PERSIST_QUEUE_TIMER)), + _ -> messagesCache.unlockQueueForPersistence(accountUuid, deviceId)) + .then(); } - private void trimQueue(final Account account, byte deviceId) { + private Mono trimQueue(final Account account, final Device device) { final UUID aci = account.getIdentifier(IdentityType.ACI); + final byte deviceId = device.getId(); - final Optional maybeDevice = account.getDevice(deviceId); - if (maybeDevice.isEmpty()) { - logger.warn("Not deleting messages for overfull queue {}::{}, deviceId {} does not exist", - aci, deviceId, deviceId); - return; - } - final Device device = maybeDevice.get(); - - // Calculate how many bytes we should trim - final long cachedMessageBytes = messagesCache.estimatePersistedQueueSizeBytes(aci, deviceId).join(); final double extraRoomRatio = this.dynamicConfigurationManager.getConfiguration() .getMessagePersisterConfiguration() .getTrimOversizedQueueExtraRoomRatio(); - final long targetDeleteBytes = Math.round(cachedMessageBytes * extraRoomRatio); final AtomicLong oldestMessage = new AtomicLong(0L); final AtomicLong newestMessage = new AtomicLong(0L); final AtomicLong bytesDeleted = new AtomicLong(0L); - // Iterate from the oldest message until we've removed targetDeleteBytes - final Pair outcomes = Flux.from(messagesManager.getMessagesForDeviceReactive(aci, device, false)) - .concatMap(envelope -> { - if (bytesDeleted.getAndAdd(envelope.getSerializedSize()) >= targetDeleteBytes) { - return Mono.just(Optional.empty()); - } - oldestMessage.compareAndSet(0L, envelope.getServerTimestamp()); - newestMessage.set(envelope.getServerTimestamp()); - return Mono.just(Optional.of(envelope)); - }) - .takeWhile(Optional::isPresent) - .flatMap(maybeEnvelope -> { - // We know this must be present because we `takeWhile` values are present - final MessageProtos.Envelope envelope = maybeEnvelope.orElseThrow(AssertionError::new); - TRIMMED_MESSAGE_COUNTER.increment(); - TRIMMED_MESSAGE_BYTES_COUNTER.increment(envelope.getSerializedSize()); - return Mono - .fromCompletionStage(() -> messagesManager - .delete(aci, device, UUID.fromString(envelope.getServerGuid()), envelope.getServerTimestamp())) - .retryWhen(retryBackoffSpec) - .map(Optional::isPresent); - }) - .reduce(Pair.of(0L, 0L), (acc, deleted) -> deleted - ? Pair.of(acc.getLeft() + 1, acc.getRight()) - : Pair.of(acc.getLeft(), acc.getRight() + 1)) - .blockOptional() - .orElseGet(() -> Pair.of(0L, 0L)); + final AtomicLong cachedMessageBytes = new AtomicLong(0L); + final AtomicLong targetDeleteBytes = new AtomicLong(0L); - logger.warn( - "Finished trimming {}:{}. Oldest message = {}, newest message = {}. Attempted to delete {} persisted bytes to make room for {} cached message bytes. Delete outcomes: {} present, {} missing.", - aci, deviceId, - Instant.ofEpochMilli(oldestMessage.get()), Instant.ofEpochMilli(newestMessage.get()), - targetDeleteBytes, cachedMessageBytes, - outcomes.getLeft(), outcomes.getRight()); + return Mono.fromFuture(() -> messagesCache.estimatePersistedQueueSizeBytes(aci, deviceId)) + .flatMap(estimatedPersistedQueueSize -> { + cachedMessageBytes.set(estimatedPersistedQueueSize); + targetDeleteBytes.set(Math.round(estimatedPersistedQueueSize * extraRoomRatio)); + + return Flux.from(messagesManager.getMessagesForDeviceReactive(aci, device, false)) + .concatMap(envelope -> { + if (bytesDeleted.getAndAdd(envelope.getSerializedSize()) >= targetDeleteBytes.get()) { + return Mono.just(Optional.empty()); + } + oldestMessage.compareAndSet(0L, envelope.getServerTimestamp()); + newestMessage.set(envelope.getServerTimestamp()); + return Mono.just(Optional.of(envelope)); + }) + .takeWhile(Optional::isPresent) + .flatMap(maybeEnvelope -> { + // We know this must be present because we `takeWhile` values are present + final MessageProtos.Envelope envelope = maybeEnvelope.orElseThrow(AssertionError::new); + TRIMMED_MESSAGE_COUNTER.increment(); + TRIMMED_MESSAGE_BYTES_COUNTER.increment(envelope.getSerializedSize()); + return Mono + .fromCompletionStage(() -> messagesManager + .delete(aci, device, UUID.fromString(envelope.getServerGuid()), envelope.getServerTimestamp())) + .retryWhen(retryBackoffSpec) + .map(Optional::isPresent); + }) + .reduce(Pair.of(0L, 0L), (acc, deleted) -> deleted + ? Pair.of(acc.getLeft() + 1, acc.getRight()) + : Pair.of(acc.getLeft(), acc.getRight() + 1)); + }) + .doOnSuccess(outcomes -> { + if (outcomes != null) { + logger.warn( + "Finished trimming {}:{}. Oldest message = {}, newest message = {}. Attempted to delete {} persisted bytes to make room for {} cached message bytes. Delete outcomes: {} present, {} missing.", + aci, deviceId, + Instant.ofEpochMilli(oldestMessage.get()), Instant.ofEpochMilli(newestMessage.get()), + targetDeleteBytes, cachedMessageBytes, + outcomes.getLeft(), outcomes.getRight()); + } + }) + .then(); } @VisibleForTesting diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 9db7dc655..5959de648 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -28,6 +28,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -139,7 +140,6 @@ public class MessagesCache { private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert")); private final Timer insertSharedMrmPayloadTimer = Metrics.timer(name(MessagesCache.class, "insertSharedMrmPayload")); - private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get")); private final Timer removeByGuidTimer = Metrics.timer(name(MessagesCache.class, "removeByGuid")); private final Timer removeRecipientViewTimer = Metrics.timer(name(MessagesCache.class, "removeRecipientView")); private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear")); @@ -314,13 +314,30 @@ public class MessagesCache { .toCompletableFuture(); } - public Publisher get(final UUID destinationUuid, final byte destinationDevice) { + public Publisher get(final UUID destinationUuid, final byte destinationDeviceId) { + return get(destinationUuid, + destinationDeviceId, + clock.instant().minus(MAX_EPHEMERAL_MESSAGE_DELAY), + false); + } - final long earliestAllowableEphemeralTimestamp = - clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); + Publisher getMessagesToPersist(final UUID accountUuid, final byte destinationDeviceId) { + return Flux.from(get(accountUuid, + destinationDeviceId, + // Discard all ephemeral messages when persisting + Instant.ofEpochMilli(Long.MAX_VALUE), + true)); + } - final Flux allMessages = getAllMessages(destinationUuid, destinationDevice, - earliestAllowableEphemeralTimestamp, PAGE_SIZE) + private Publisher get(final UUID destinationUuid, + final byte destinationDeviceId, + final Instant earliestAllowableEphemeralTimestamp, + final boolean bypassLock) { + + final long earliestAllowableEphemeralTimestampMillis = earliestAllowableEphemeralTimestamp.toEpochMilli(); + + final Flux allMessages = getAllMessages(destinationUuid, destinationDeviceId, + earliestAllowableEphemeralTimestampMillis, PAGE_SIZE, bypassLock) .publish() // We expect exactly three subscribers to this base flux: // 1. the websocket that delivers messages to clients @@ -332,23 +349,23 @@ public class MessagesCache { final Flux messagesToPublish = allMessages .filter(Predicate.not(envelope -> - isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp) || isStaleMrmMessage(envelope))); + isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestampMillis) || isStaleMrmMessage(envelope))); final Flux staleEphemeralMessages = allMessages - .filter(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestamp)); - discardStaleMessages(destinationUuid, destinationDevice, staleEphemeralMessages, staleEphemeralMessagesCounter, "ephemeral"); + .filter(envelope -> isStaleEphemeralMessage(envelope, earliestAllowableEphemeralTimestampMillis)); + discardStaleMessages(destinationUuid, destinationDeviceId, staleEphemeralMessages, staleEphemeralMessagesCounter, "ephemeral"); final Flux staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage) // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data .map(envelope -> envelope.toBuilder().clearSharedMrmKey().build()); - discardStaleMessages(destinationUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm"); + discardStaleMessages(destinationUuid, destinationDeviceId, staleMrmMessages, staleMrmMessagesCounter, "mrm"); return messagesToPublish.name(GET_FLUX_NAME) .tap(Micrometer.metrics(Metrics.globalRegistry)); } public Mono getEarliestUndeliveredTimestamp(final UUID destinationUuid, final byte destinationDevice) { - return getAllMessages(destinationUuid, destinationDevice, -1, 1) + return getAllMessages(destinationUuid, destinationDevice, -1, 1, true) .next() .map(MessageProtos.Envelope::getServerTimestamp); } @@ -380,18 +397,21 @@ public class MessagesCache { } @VisibleForTesting - Flux getAllMessages(final UUID destinationUuid, final byte destinationDevice, - final long earliestAllowableEphemeralTimestamp, final int pageSize) { + Flux getAllMessages(final UUID destinationUuid, + final byte destinationDevice, + final long earliestAllowableEphemeralTimestamp, + final int pageSize, + final boolean bypassLock) { // fetch messages by page - return getNextMessagePage(destinationUuid, destinationDevice, -1, pageSize) + return getNextMessagePage(destinationUuid, destinationDevice, -1, pageSize, bypassLock) .expand(queueItemsAndLastMessageId -> { // expand() is breadth-first, so each page will be published in order if (queueItemsAndLastMessageId.first().isEmpty()) { return Mono.empty(); } - return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second(), pageSize); + return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second(), pageSize, bypassLock); }) .limitRate(1) // we want to ensure we don’t accidentally block the Lettuce/netty i/o executors @@ -536,10 +556,13 @@ public class MessagesCache { .subscribe(); } - private Mono, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice, - long messageId, int pageSize) { + private Mono, Long>> getNextMessagePage(final UUID destinationUuid, + final byte destinationDevice, + final long messageId, + final int pageSize, + final boolean bypassLock) { - return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId) + return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId, bypassLock) .map(queueItems -> { logger.trace("Processing page: {}", messageId); @@ -590,35 +613,6 @@ public class MessagesCache { .toFuture(); } - List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, - final int limit) { - - final Timer.Sample sample = Timer.start(); - - final List messages = redisCluster.withBinaryCluster(connection -> - connection.sync().zrange(getMessageQueueKey(accountUuid, destinationDevice), 0, limit)); - - final Flux allMessages = parseAndFetchMrms(Flux.fromIterable(messages), destinationDevice); - - final Flux messagesToPersist = allMessages - .filter(Predicate.not(envelope -> - envelope.getEphemeral() || isStaleMrmMessage(envelope))); - - final Flux ephemeralMessages = allMessages - .filter(MessageProtos.Envelope::getEphemeral); - discardStaleMessages(accountUuid, destinationDevice, ephemeralMessages, staleEphemeralMessagesCounter, "ephemeral"); - - final Flux staleMrmMessages = allMessages.filter(MessagesCache::isStaleMrmMessage) - // clearing the sharedMrmKey prevents unnecessary calls to update the shared MRM data - .map(envelope -> envelope.toBuilder().clearSharedMrmKey().build()); - discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm"); - - return messagesToPersist - .collectList() - .doOnTerminate(() -> sample.stop(getMessagesTimer)) - .block(Duration.ofSeconds(5)); - } - private Flux parseAndFetchMrms(final Flux serializedMessages, final byte destinationDevice) { return serializedMessages .mapNotNull(message -> { @@ -720,13 +714,14 @@ public class MessagesCache { ScanArgs.Builder.matches("user_queue::*").limit(scanCount))); } - void lockQueueForPersistence(final UUID accountUuid, final byte deviceId) { - redisCluster.useBinaryCluster( - connection -> connection.sync().setex(getPersistInProgressKey(accountUuid, deviceId), 30, LOCK_VALUE)); + Mono lockQueueForPersistence(final UUID accountUuid, final byte deviceId) { + return redisCluster.withBinaryCluster( + connection -> connection.reactive().setex(getPersistInProgressKey(accountUuid, deviceId), 30, LOCK_VALUE)) + .then(); } - void unlockQueueForPersistence(final UUID accountUuid, final byte deviceId) { - unlockQueueScript.execute(accountUuid, deviceId); + Mono unlockQueueForPersistence(final UUID accountUuid, final byte deviceId) { + return unlockQueueScript.execute(accountUuid, deviceId); } static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScript.java index e25d8e51e..f10d45263 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScript.java @@ -28,15 +28,20 @@ class MessagesCacheGetItemsScript { this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.OBJECT); } - Mono> execute(final UUID destinationUuid, final byte destinationDevice, - int limit, long afterMessageId) { + Mono> execute(final UUID destinationUuid, + final byte destinationDevice, + final int limit, + final long afterMessageId, + final boolean bypassLock) { + final List keys = List.of( MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey MessagesCache.getPersistInProgressKey(destinationUuid, destinationDevice) // queueLockKey ); final List args = List.of( String.valueOf(limit).getBytes(StandardCharsets.UTF_8), // limit - String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8) // afterMessageId + String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8), // afterMessageId + String.valueOf(bypassLock).getBytes(StandardCharsets.UTF_8) // bypassLock ); //noinspection unchecked return getItemsScript.executeBinaryReactive(keys, args) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheUnlockQueueScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheUnlockQueueScript.java index 1be273840..342fd13a6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheUnlockQueueScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheUnlockQueueScript.java @@ -6,15 +6,17 @@ package org.whispersystems.textsecuregcm.storage; import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.UUID; import org.whispersystems.textsecuregcm.push.ClientEvent; import org.whispersystems.textsecuregcm.push.MessagesPersistedEvent; import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; -import org.whispersystems.textsecuregcm.util.ResilienceUtil; -import java.io.IOException; -import java.util.List; -import java.util.UUID; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; /** * Unlocks a message queue for persistence/message retrieval. @@ -33,13 +35,14 @@ class MessagesCacheUnlockQueueScript { ClusterLuaScript.fromResource(redisCluster, "lua/unlock_queue.lua", ScriptOutputType.STATUS); } - void execute(final UUID accountIdentifier, final byte deviceId) { + Mono execute(final UUID accountIdentifier, final byte deviceId) { final List keys = List.of( MessagesCache.getPersistInProgressKey(accountIdentifier, deviceId), // persistInProgressKey RedisMessageAvailabilityManager.getClientEventChannel(accountIdentifier, deviceId) // eventChannelKey ); - ResilienceUtil.getGeneralRedisRetry(MessagesCache.RETRY_NAME) - .executeRunnable(() -> unlockQueueScript.executeBinary(keys, MESSAGES_PERSISTED_EVENT_ARGS)); + return unlockQueueScript.executeBinaryReactive(keys, MESSAGES_PERSISTED_EVENT_ARGS) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1))) + .then(); } } diff --git a/service/src/main/resources/lua/get_items.lua b/service/src/main/resources/lua/get_items.lua index cb930b53b..65fa73f3e 100644 --- a/service/src/main/resources/lua/get_items.lua +++ b/service/src/main/resources/lua/get_items.lua @@ -2,14 +2,17 @@ -- returns a list of all envelopes and their queue-local IDs local queueKey = KEYS[1] -- sorted set of all Envelopes for a device, scored by queue-local ID -local queueLockKey = KEYS[2] -- a key whose presence indicates that the queue is being persistent and must not be read +local queueLockKey = KEYS[2] -- a key whose presence indicates that the queue is being persisted and must not be read local limit = ARGV[1] -- [number] the maximum number of messages to return local afterMessageId = ARGV[2] -- [number] a queue-local ID to exclusively start after, to support pagination. Use -1 to start at the beginning +local bypassLock = ARGV[3] -- [string] whether to bypass the persistence lock (i.e. when fetching messages for persistence) -local locked = redis.call("GET", queueLockKey) +if bypassLock ~= "true" then + local locked = redis.call("GET", queueLockKey) -if locked then - return {} + if locked then + return {} + end end if afterMessageId == "null" or afterMessageId == nil then diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 215d3cc23..178d0e877 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -218,7 +218,7 @@ class MessagePersisterIntegrationTest { } @Test - void testPersistFirstPageDiscarded() throws MessagePersistenceException { + void testPersistFirstPageDiscarded() { final int discardableMessages = MessagePersister.MESSAGE_BATCH_LIMIT * 2; final int persistableMessages = MessagePersister.MESSAGE_BATCH_LIMIT + 1; @@ -245,7 +245,7 @@ class MessagePersisterIntegrationTest { expectedMessages.add(message); } - messagePersister.persistQueue(account, account.getDevice(Device.PRIMARY_ID).orElseThrow(), Tags.empty()); + messagePersister.persistQueue(account, account.getDevice(Device.PRIMARY_ID).orElseThrow(), Tags.empty()).block(); final DynamoDbClient dynamoDB = DYNAMO_DB_EXTENSION.getDynamoDbClient(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 248f67b5f..137414767 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -64,6 +64,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; import reactor.util.retry.Retry; import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException; @@ -176,6 +177,25 @@ class MessagePersisterTest { resubscribeRetryExecutorService.awaitTermination(1, TimeUnit.SECONDS); } + @Test + void persistQueue() { + final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, + CLOCK.instant().minus(PERSIST_DELAY.plusSeconds(1))); + + messagePersister.persistQueue(destinationAccount, destinationAccount.getDevice(DESTINATION_DEVICE_ID).orElseThrow(), Tags.empty()) + .block(); + + @SuppressWarnings("unchecked") final ArgumentCaptor> messagesCaptor = + ArgumentCaptor.forClass(List.class); + + verify(messagesDynamoDb, atLeastOnce()) + .store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE)); + + assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); + } + @Test void persistNextNodeNoQueues() { assertEquals(0, messagePersister.persistNextNode()); @@ -407,39 +427,6 @@ class MessagePersisterTest { assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } - @Test - void persistNodePersistQueueMessagePersistenceException() { - final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; - - insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, - CLOCK.instant().minus(PERSIST_DELAY.plusSeconds(1))); - - // Provoke a MessagePersistenceException - when(messagesManager.persistMessages(any(), any(), any())).thenReturn(0); - - assertEquals(0, messagePersister.persistNode( - getNodeWithKey(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID)))); - - // We use this as a proxy for attempts to persist messages; for a MessagePersistenceException, we should NOT retry, - // and this should happen exactly once - verify(messagesCache).lockQueueForPersistence(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID); - verify(messagesDynamoDb, never()).store(any(), any(), any()); - } - - @Test - void testPersistQueueRetryLoop() { - final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; - - insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, CLOCK.instant().minus(PERSIST_DELAY.plusSeconds(1))); - - // returning `0` indicates something not working correctly - when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenReturn(0); - - assertTimeoutPreemptively(Duration.ofSeconds(1), () -> - assertThrows(MessagePersistenceException.class, - () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()))); - } - @Test void testUnlinkOnFullQueue() { final int messageCount = 1; @@ -469,12 +456,15 @@ class MessagePersisterTest { final Device destination = mock(Device.class); when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID); - when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination)); + when(destinationAccount.getDevices()) + .thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination)); - when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); + when(messagesManager.persistMessages(any(), any(), any())) + .thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> - messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty())); + messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block()); + verify(accountsManager, exactly()).removeDevice(destinationAccount, DESTINATION_DEVICE_ID); } @@ -520,8 +510,9 @@ class MessagePersisterTest { when(messagesManager.delete(any(), any(), any(), anyLong())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - assertThrows(MessagePersistenceException.class, () -> - messagePersister.persistQueue(destinationAccount, primary, Tags.empty())); + StepVerifier.create(messagePersister.persistQueue(destinationAccount, primary, Tags.empty())) + .expectError(MessagePersistenceException.class) + .verify(); verify(messagesManager, times(expectedClearedGuids.size())) .delete(eq(DESTINATION_ACCOUNT_UUID), eq(primary), argThat(expectedClearedGuids::contains), anyLong()); @@ -563,7 +554,7 @@ class MessagePersisterTest { when(messagesManager.persistMessages(any(UUID.class), any(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(accountsManager.removeDevice(destinationAccount, DESTINATION_DEVICE_ID)).thenThrow(new RuntimeException()); - assertThrows(RuntimeException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty())); + assertThrows(RuntimeException.class, () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE, Tags.empty()).block()); } private static RedisClusterNode getNodeWithKey(final byte[] key) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java index 61af1dc37..2aebeaf7b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java @@ -49,7 +49,7 @@ class MessagesCacheGetItemsScriptTest { final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript( REDIS_CLUSTER_EXTENSION.getRedisCluster()); - final List messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1) + final List messageAndScores = getItemsScript.execute(destinationUuid, deviceId, 1, -1, false) .block(Duration.ofSeconds(1)); assertNotNull(messageAndScores); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java index a0dcce8f7..73b0b732d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java @@ -68,7 +68,7 @@ class MessagesCacheInsertScriptTest { final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(REDIS_CLUSTER_EXTENSION.getRedisCluster()); - final List queueItems = getItemsScript.execute(destinationUuid, deviceId, 1024, 0) + final List queueItems = getItemsScript.execute(destinationUuid, deviceId, 1024, 0, false) .blockOptional() .orElseGet(Collections::emptyList); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index da68dc2d4..0df85ee19 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -139,7 +139,7 @@ class MessagesCacheTest { messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join(); messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join(); - assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10) + assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10, false) .count() .blockOptional() .orElse(0L)); @@ -319,15 +319,41 @@ class MessagesCacheTest { expectedMessages.add(message); } - messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); + messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block(); assertTrue(get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount).isEmpty()); - messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); + messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block(); assertEquals(expectedMessages, get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } + @Test + void testGetMessagesToPersistLockedForPersistence() { + final int messageCount = 100; + + final List expectedMessages = new ArrayList<>(messageCount); + + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); + expectedMessages.add(message); + } + + messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block(); + + try { + assertEquals(expectedMessages, + Flux.from(messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID)) + .collectList() + .blockOptional() + .orElseThrow()); + } finally { + messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID).block(); + } + } + @ParameterizedTest @ValueSource(booleans = {true, false}) void testGetMessagesPublisher(final boolean expectStale) throws Exception { @@ -385,7 +411,7 @@ class MessagesCacheTest { .get(5, TimeUnit.SECONDS); final List messages = messagesCache.getAllMessages(DESTINATION_UUID, - DESTINATION_DEVICE_ID, 0, 10) + DESTINATION_DEVICE_ID, 0, 10, false) .collectList() .toFuture().get(5, TimeUnit.SECONDS); @@ -684,7 +710,11 @@ class MessagesCacheTest { .build(); messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage).join(); - final List messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100); + final List messages = + Flux.from(messagesCache.getMessagesToPersist(destinationUuid, deviceId)) + .collectList() + .blockOptional() + .orElseThrow(); if (!sharedMrmKeyPresent) { assertEquals(1, messages.size()); @@ -778,7 +808,7 @@ class MessagesCacheTest { .thenReturn(Flux.from(emptyFinalPagePublisher)) .thenReturn(Flux.empty()); - final Flux allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0, 10); + final Flux allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0, 10, false); // Why initialValue = 3? // 1. messagesCache.getAllMessages() above produces the first call