Automatically trim primary queue when cache cannot be persisted

Ravi Khadiwala
2025-02-28 11:11:42 -06:00
committed by ravi-signal
parent 8517eef3fe
commit 09b50383d7
6 changed files with 280 additions and 35 deletions

View File

@@ -6,13 +6,33 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
public class DynamicMessagePersisterConfiguration {
@JsonProperty
private boolean persistenceEnabled = true;
/**
 * If we have to trim a client's persisted queue to make room to persist messages from Redis to DynamoDB,
 * this ratio controls how much extra room to make beyond the size of the cached backlog.
 */
@JsonProperty
private double trimOversizedQueueExtraRoomRatio = 1.5;
public DynamicMessagePersisterConfiguration() {}
@VisibleForTesting
public DynamicMessagePersisterConfiguration(final boolean persistenceEnabled, final double trimOversizedQueueExtraRoomRatio) {
this.persistenceEnabled = persistenceEnabled;
this.trimOversizedQueueExtraRoomRatio = trimOversizedQueueExtraRoomRatio;
}
public boolean isPersistenceEnabled() {
return persistenceEnabled;
}
public double getTrimOversizedQueueExtraRoomRatio() {
return trimOversizedQueueExtraRoomRatio;
}
}
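
For context, this ratio feeds directly into the trim target computed in MessagePersister.trimQueue (see below): targetDeleteBytes = Math.round(cachedMessageBytes * extraRoomRatio). A minimal worked sketch with made-up byte counts (the class name and values are illustrative, not part of this commit):

public final class TrimRatioExample {
  public static void main(String[] args) {
    // Suppose 800,000 bytes of messages are waiting in the Redis cache.
    final long cachedMessageBytes = 800_000L;
    // With the default ratio of 1.5, the persister aims to delete
    // 1,200,000 bytes of already-persisted messages, freeing 50% more
    // room than the cached backlog strictly requires.
    final double extraRoomRatio = 1.5;
    final long targetDeleteBytes = Math.round(cachedMessageBytes * extraRoomRatio);
    System.out.println(targetDeleteBytes); // prints 1200000
  }
}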

View File

@@ -18,11 +18,17 @@ import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
public class MessagePersister implements Managed {
@@ -30,17 +36,21 @@ public class MessagePersister implements Managed {
private final MessagesCache messagesCache;
private final MessagesManager messagesManager;
private final AccountsManager accountsManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final Duration persistDelay;
private final Thread[] workerThreads;
private volatile boolean running;
private static final String OVERSIZED_QUEUE_COUNTER_NAME = name(MessagePersister.class, "persistQueueOversized");
private final Timer getQueuesTimer = Metrics.timer(name(MessagePersister.class, "getQueues"));
private final Timer persistQueueTimer = Metrics.timer(name(MessagePersister.class, "persistQueue"));
private final Counter persistQueueExceptionMeter = Metrics.counter(
name(MessagePersister.class, "persistQueueException"));
private static final Counter trimmedMessageCounter = Metrics.counter(name(MessagePersister.class, "trimmedMessage"));
private static final Counter trimmedMessageBytesCounter = Metrics.counter(name(MessagePersister.class, "trimmedMessageBytes"));
private final DistributionSummary queueCountDistributionSummery = DistributionSummary.builder(
name(MessagePersister.class, "queueCount"))
.publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999)
@@ -54,6 +64,7 @@ public class MessagePersister implements Managed {
static final int QUEUE_BATCH_LIMIT = 100;
static final int MESSAGE_BATCH_LIMIT = 100;
static final int CACHE_PAGE_SIZE = 100;
private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis();
@@ -66,12 +77,12 @@ public class MessagePersister implements Managed {
final AccountsManager accountsManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final Duration persistDelay,
final int dedicatedProcessWorkerThreadCount) {
this.messagesCache = messagesCache;
this.messagesManager = messagesManager;
this.accountsManager = accountsManager;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.persistDelay = persistDelay;
this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount];
@@ -159,7 +170,10 @@ public class MessagePersister implements Managed {
messagesCache.addQueueToPersist(accountUuid, deviceId);
if (!(e instanceof MessagePersistenceException)) {
// Pause after unexpected exceptions
Util.sleep(EXCEPTION_PAUSE_MILLIS);
}
}
}
@@ -204,8 +218,19 @@ public class MessagePersister implements Managed {
queueSizeDistributionSummery.record(messageCount);
} catch (ItemCollectionSizeLimitExceededException e) {
final boolean isPrimary = deviceId == Device.PRIMARY_ID;
Metrics.counter(OVERSIZED_QUEUE_COUNTER_NAME, "primary", String.valueOf(isPrimary)).increment();
// may throw, in which case we'll retry later by the usual mechanism
if (isPrimary) {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will trim oldest messages",
account.getUuid(), deviceId);
trimQueue(account, deviceId);
throw new MessagePersistenceException("Could not persist due to an overfull queue. Trimmed primary queue, a subsequent retry may succeed");
} else {
logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device", account.getUuid(),
deviceId);
accountsManager.removeDevice(account, deviceId).join();
}
} finally {
messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
sample.stop(persistQueueTimer);
@@ -213,13 +238,62 @@ public class MessagePersister implements Managed {
}
private void trimQueue(final Account account, byte deviceId) {
final UUID aci = account.getIdentifier(IdentityType.ACI);
final Optional<Device> maybeDevice = account.getDevice(deviceId);
if (maybeDevice.isEmpty()) {
logger.warn("Not deleting messages for overfull queue {}::{}, deviceId {} does not exist",
aci, deviceId, deviceId);
return;
}
final Device device = maybeDevice.get();
// Calculate how many bytes we should trim
final long cachedMessageBytes = Flux
.from(messagesCache.getMessagesToPersistReactive(aci, deviceId, CACHE_PAGE_SIZE))
.reduce(0, (acc, envelope) -> acc + envelope.getSerializedSize())
.block();
final double extraRoomRatio = this.dynamicConfigurationManager.getConfiguration()
.getMessagePersisterConfiguration()
.getTrimOversizedQueueExtraRoomRatio();
final long targetDeleteBytes = Math.round(cachedMessageBytes * extraRoomRatio);
final AtomicLong oldestMessage = new AtomicLong(0L);
final AtomicLong newestMessage = new AtomicLong(0L);
final AtomicLong bytesDeleted = new AtomicLong(0L);
// Iterate from the oldest message until we've removed targetDeleteBytes
final Pair<Long, Long> outcomes = Flux.from(messagesManager.getMessagesForDeviceReactive(aci, device, false))
.concatMap(envelope -> {
if (bytesDeleted.getAndAdd(envelope.getSerializedSize()) >= targetDeleteBytes) {
return Mono.just(Optional.<MessageProtos.Envelope>empty());
}
oldestMessage.compareAndSet(0L, envelope.getServerTimestamp());
newestMessage.set(envelope.getServerTimestamp());
return Mono.just(Optional.of(envelope));
})
.takeWhile(Optional::isPresent)
.flatMap(maybeEnvelope -> {
final MessageProtos.Envelope envelope = maybeEnvelope.get();
trimmedMessageCounter.increment();
trimmedMessageBytesCounter.increment(envelope.getSerializedSize());
return Mono
.fromCompletionStage(() -> messagesManager
.delete(aci, device, UUID.fromString(envelope.getServerGuid()), envelope.getServerTimestamp()))
.retryWhen(Retry.backoff(5, Duration.ofSeconds(1)))
.map(Optional::isPresent);
})
.reduce(Pair.of(0L, 0L), (acc, deleted) -> deleted
? Pair.of(acc.getLeft() + 1, acc.getRight())
: Pair.of(acc.getLeft(), acc.getRight() + 1))
.block();
logger.warn(
"Finished trimming {}:{}. Oldest message = {}, newest message = {}. Attempted to delete {} persisted bytes to make room for {} cached message bytes. Delete outcomes: {} present, {} missing.",
aci, deviceId,
Instant.ofEpochMilli(oldestMessage.get()), Instant.ofEpochMilli(newestMessage.get()),
targetDeleteBytes, cachedMessageBytes,
outcomes.getLeft(), outcomes.getRight());
}
}
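
The concatMap/takeWhile pair above is a general Reactor idiom for cancelling a stream once a running total crosses a threshold: each element is wrapped in an Optional, which flips to Optional.empty() once the budget is spent, and takeWhile(Optional::isPresent) then cancels the upstream. A self-contained sketch of the same cutoff with plain numbers (names and values here are illustrative):

import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import reactor.core.publisher.Flux;

public final class ByteBudgetCutoffExample {
  public static void main(String[] args) {
    final AtomicLong runningTotal = new AtomicLong(0L);
    final long targetBytes = 10L;
    Flux.just(4L, 5L, 3L, 7L)
        // getAndAdd compares the total *before* the current element, so the
        // stream stops only after the running total has reached the target;
        // like trimQueue, it may slightly overshoot the budget.
        .map(size -> runningTotal.getAndAdd(size) >= targetBytes
            ? Optional.<Long>empty()
            : Optional.of(size))
        .takeWhile(Optional::isPresent)
        .map(Optional::get)
        .subscribe(System.out::println); // prints 4, 5, 3 and then cancels
  }
}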

View File

@@ -10,6 +10,9 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.lettuce.core.Limit;
import io.lettuce.core.Range;
import io.lettuce.core.ScoredValue;
import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash;
import io.micrometer.core.instrument.Counter;
@@ -31,6 +34,7 @@ import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.function.Predicate;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
@@ -528,15 +532,28 @@ public class MessagesCache {
});
}
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice,
final int limit) {
Flux<MessageProtos.Envelope> getMessagesToPersistReactive(final UUID accountUuid, final byte destinationDevice,
final int pageSize) {
final Timer.Sample sample = Timer.start();
final List<byte[]> messages = redisCluster.withBinaryCluster(connection ->
connection.sync().zrange(getMessageQueueKey(accountUuid, destinationDevice), 0, limit));
final Function<Long, Mono<List<ScoredValue<byte[]>>>> getNextPage = (Long start) ->
Mono.fromCompletionStage(() -> redisCluster.withBinaryCluster(connection ->
connection.async().zrangebyscoreWithScores(
getMessageQueueKey(accountUuid, destinationDevice),
Range.from(
Range.Boundary.excluding(start),
Range.Boundary.unbounded()),
Limit.from(pageSize))));
final Flux<MessageProtos.Envelope> allMessages = Flux.fromIterable(messages)
final Flux<MessageProtos.Envelope> allMessages = getNextPage.apply(0L)
.expand(scoredValues -> {
if (scoredValues.isEmpty()) {
return Mono.empty();
}
long lastTimestamp = (long) scoredValues.getLast().getScore();
return getNextPage.apply(lastTimestamp);
})
.concatMap(scoredValues -> Flux.fromStream(scoredValues.stream().map(ScoredValue::getValue)))
.mapNotNull(message -> {
try {
return MessageProtos.Envelope.parseFrom(message);
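
The getNextPage/expand combination above implements cursor-style pagination: each page is fetched by score, and the score of a page's last element becomes the exclusive lower bound of the next request. A minimal self-contained sketch of the same Mono.expand pattern, with a TreeMap standing in for the Redis sorted set (all names and values here are illustrative):

import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public final class ExpandPaginationExample {
  public static void main(String[] args) {
    // Stand-in for the Redis sorted set: score -> member.
    final NavigableMap<Long, String> sortedSet = new TreeMap<>();
    for (long score = 1; score <= 7; score++) {
      sortedSet.put(score, "message-" + score);
    }
    final int pageSize = 3;

    // One page of entries with scores strictly greater than the cursor,
    // mirroring ZRANGEBYSCORE with an exclusive lower bound and LIMIT 0 pageSize.
    final Function<Long, Mono<List<Map.Entry<Long, String>>>> getNextPage =
        cursor -> Mono.fromSupplier(() -> sortedSet.tailMap(cursor, false).entrySet().stream()
            .limit(pageSize)
            .collect(Collectors.toList()));

    getNextPage.apply(0L)
        .expand(page -> page.isEmpty()
            ? Mono.empty()
            // The last score on the page is the cursor for the next page.
            : getNextPage.apply(page.get(page.size() - 1).getKey()))
        .concatMap(Flux::fromIterable)
        .map(Map.Entry::getValue)
        .subscribe(System.out::println); // message-1 through message-7, in score order
  }
}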
@@ -554,7 +571,15 @@ public class MessagesCache {
}
return messageMono;
});
})
.publish()
// We expect exactly three subscribers to this base flux:
// 1. the caller of the method
// 2. an internal process to discard stale ephemeral messages
// 3. an internal process to discard stale MRM messages
// The discard subscribers will subscribe immediately, but we don't want to do any work if the
// caller never subscribes
.autoConnect(3);
final Flux<MessageProtos.Envelope> messagesToPersist = allMessages
.filter(Predicate.not(envelope ->
@@ -570,8 +595,14 @@ public class MessagesCache {
discardStaleMessages(accountUuid, destinationDevice, staleMrmMessages, staleMrmMessagesCounter, "mrm");
return messagesToPersist
.doOnTerminate(() -> sample.stop(getMessagesTimer));
}
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice,
final int limit) {
return getMessagesToPersistReactive(accountUuid, destinationDevice, limit)
.take(limit)
.collectList()
.block(Duration.ofSeconds(5));
}
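
To make the three-subscriber contract above concrete, here is a minimal sketch of publish().autoConnect(n) behavior (subscriber names are illustrative):

import reactor.core.publisher.Flux;

public final class AutoConnectExample {
  public static void main(String[] args) {
    // publish() multicasts the flux; autoConnect(3) defers the upstream
    // subscription (and thus all work) until three subscribers have arrived,
    // after which each of them sees every element exactly once.
    final Flux<Integer> shared = Flux.range(1, 3)
        .doOnSubscribe(s -> System.out.println("upstream subscribed"))
        .publish()
        .autoConnect(3);

    shared.subscribe(i -> System.out.println("caller: " + i));
    shared.subscribe(i -> System.out.println("ephemeral-discard: " + i));
    // Nothing has happened yet; the upstream runs only once this third
    // subscriber connects.
    shared.subscribe(i -> System.out.println("mrm-discard: " + i));
  }
}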

View File

@@ -11,7 +11,6 @@ import static io.micrometer.core.instrument.Metrics.timer;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.InvalidProtocolBufferException;
import io.micrometer.core.instrument.Timer;
import java.nio.ByteBuffer;
import java.time.Duration;
@@ -24,7 +23,6 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Predicate;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;