Multi-recipient message views

This adds support for storing multi-recipient message payloads and recipient views in Redis, and only fanning out on delivery or persistence. Phase 1: confirm storage and retrieval correctness.
This commit is contained in:
Chris Eager
2024-09-04 13:58:20 -05:00
committed by GitHub
parent d78c8370b6
commit 11601fd091
50 changed files with 1544 additions and 328 deletions

View File

@@ -632,7 +632,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
keyspaceNotificationDispatchExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionAsyncExecutor, clock);
messageDeliveryScheduler, messageDeletionAsyncExecutor, clock, dynamicConfigurationManager);
ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases,
recurringJobExecutor,
config.getClientReleaseConfiguration().refreshInterval(),

View File

@@ -27,5 +27,4 @@ public class MessageCacheConfiguration {
public int getPersistDelayMinutes() {
return persistDelayMinutes;
}
}

View File

@@ -5,21 +5,9 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
import java.util.List;
import javax.validation.constraints.NotNull;
public record DynamicMessagesConfiguration(@NotNull List<DynamoKeyScheme> dynamoKeySchemes) {
public enum DynamoKeyScheme {
TRADITIONAL,
LAZY_DELETION;
}
public record DynamicMessagesConfiguration(boolean storeSharedMrmData, boolean mrmViewExperimentEnabled) {
public DynamicMessagesConfiguration() {
this(List.of(DynamoKeyScheme.TRADITIONAL));
}
public DynamoKeyScheme writeKeyScheme() {
return dynamoKeySchemes().getLast();
this(false, false);
}
}

View File

@@ -24,7 +24,6 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.security.MessageDigest;
import java.time.Clock;
import java.time.Duration;
@@ -73,8 +72,8 @@ import javax.ws.rs.core.Response.Status;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException;
@@ -261,7 +260,7 @@ public class MessageController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
@Operation(
@Operation(
summary = "Send a message",
description = """
Deliver a message to a single recipient. May be authenticated or unauthenticated; if unauthenticated,
@@ -309,9 +308,10 @@ public class MessageController {
if (groupSendToken != null) {
if (!source.isEmpty() || !accessKey.isEmpty()) {
throw new BadRequestException("Group send endorsement tokens should not be combined with other authentication");
throw new BadRequestException(
"Group send endorsement tokens should not be combined with other authentication");
} else if (isStory) {
throw new BadRequestException("Group send endorsement tokens should not be sent for story messages");
throw new BadRequestException("Group send endorsement tokens should not be sent for story messages");
}
}
@@ -346,8 +346,7 @@ public class MessageController {
}
final Optional<byte[]> spamReportToken = switch (senderType) {
case SENDER_TYPE_IDENTIFIED ->
reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination);
case SENDER_TYPE_IDENTIFIED -> reportSpamTokenProvider.makeReportSpamToken(context, source.get(), destination);
default -> Optional.empty();
};
@@ -470,7 +469,7 @@ public class MessageController {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
@@ -621,27 +620,28 @@ public class MessageController {
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
recipients.values().forEach(recipient -> {
final Account account = recipient.account();
final Account account = recipient.account();
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), Collections.emptySet());
try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(),
Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.deviceIdToRegistrationId().entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
}
});
DestinationDeviceValidator.validateRegistrationIds(
account,
recipient.deviceIdToRegistrationId().entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
recipient.serviceIdentifier(),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
@@ -667,6 +667,11 @@ public class MessageController {
}
try {
@Nullable final byte[] sharedMrmKey =
dynamicConfigurationManager.getConfiguration().getMessagesConfiguration().storeSharedMrmData()
? messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage)
: null;
CompletableFuture.allOf(
recipients.values().stream()
.flatMap(recipientData -> {
@@ -692,8 +697,7 @@ public class MessageController {
sentMessageCounter.increment();
sendCommonPayloadMessage(
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp,
online,
isStory, isUrgent, payload);
online, isStory, isUrgent, payload, sharedMrmKey);
},
multiRecipientMessageExecutor));
})
@@ -739,8 +743,8 @@ public class MessageController {
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
.map(account ->
account.getUnidentifiedAccessKey()
.filter(b -> b.length == keyLength)
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.filter(b -> b.length == keyLength)
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.reduce(new byte[keyLength],
(a, b) -> {
final byte[] xor = new byte[keyLength];
@@ -828,23 +832,28 @@ public class MessageController {
auth.getAuthenticatedDevice(),
uuid,
null)
.thenAccept(maybeDeletedMessage -> {
maybeDeletedMessage.ifPresent(deletedMessage -> {
.thenAccept(maybeRemovedMessage -> maybeRemovedMessage.ifPresent(removedMessage -> {
WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getServerTimestamp(),
auth.getAuthenticatedDevice());
WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(),
auth.getAuthenticatedDevice());
if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) {
if (removedMessage.sourceServiceId().isPresent()
&& removedMessage.envelopeType() != Type.SERVER_DELIVERY_RECEIPT) {
if (removedMessage.sourceServiceId().get() instanceof AciServiceIdentifier aciServiceIdentifier) {
try {
receiptSender.sendReceipt(
ServiceIdentifier.valueOf(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(),
AciServiceIdentifier.valueOf(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getAuthenticatedDevice().getId(),
aciServiceIdentifier, removedMessage.clientTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
}
} else {
// If source service ID is present and the envelope type is not a server delivery receipt, then
// the source service ID *should always* be an ACI -- PNIs are receive-only, so they can only be the
// "source" via server delivery receipts
logger.warn("Source service ID unexpectedly a PNI service ID");
}
});
})
}
}))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@@ -943,19 +952,25 @@ public class MessageController {
boolean online,
boolean story,
boolean urgent,
byte[] payload) {
byte[] payload,
@Nullable byte[] sharedMrmKey) {
final Envelope.Builder messageBuilder = Envelope.newBuilder();
final long serverTimestamp = System.currentTimeMillis();
messageBuilder
.setType(Type.UNIDENTIFIED_SENDER)
.setTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setClientTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setServerTimestamp(serverTimestamp)
.setContent(ByteString.copyFrom(payload))
.setStory(story)
.setUrgent(urgent)
.setDestinationUuid(serviceIdentifier.toServiceIdentifierString());
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString());
if (sharedMrmKey != null) {
messageBuilder.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
}
// mrm views phase 1: always set content
messageBuilder.setContent(ByteString.copyFrom(payload));
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
}

View File

@@ -31,15 +31,15 @@ public record IncomingMessage(int type, byte destinationDeviceId, int destinatio
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder();
envelopeBuilder.setType(envelopeType)
.setTimestamp(timestamp)
.setClientTimestamp(timestamp)
.setServerTimestamp(System.currentTimeMillis())
.setDestinationUuid(destinationIdentifier.toServiceIdentifierString())
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setStory(story)
.setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) {
envelopeBuilder
.setSourceUuid(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceServiceId(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice(sourceDeviceId.intValue());
}

View File

@@ -40,15 +40,15 @@ public record OutgoingMessageEntity(UUID guid,
public MessageProtos.Envelope toEnvelope() {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type()))
.setTimestamp(timestamp())
.setClientTimestamp(timestamp())
.setServerTimestamp(serverTimestamp())
.setDestinationUuid(destinationUuid().toServiceIdentifierString())
.setDestinationServiceId(destinationUuid().toServiceIdentifierString())
.setServerGuid(guid().toString())
.setStory(story)
.setUrgent(urgent);
if (sourceUuid() != null) {
builder.setSourceUuid(sourceUuid().toServiceIdentifierString());
builder.setSourceServiceId(sourceUuid().toServiceIdentifierString());
builder.setSourceDevice(sourceDevice());
}
@@ -72,10 +72,10 @@ public record OutgoingMessageEntity(UUID guid,
return new OutgoingMessageEntity(
UUID.fromString(envelope.getServerGuid()),
envelope.getType().getNumber(),
envelope.getTimestamp(),
envelope.hasSourceUuid() ? ServiceIdentifier.valueOf(envelope.getSourceUuid()) : null,
envelope.getClientTimestamp(),
envelope.hasSourceServiceId() ? ServiceIdentifier.valueOf(envelope.getSourceServiceId()) : null,
envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? ServiceIdentifier.valueOf(envelope.getDestinationUuid()) : null,
envelope.hasDestinationServiceId() ? ServiceIdentifier.valueOf(envelope.getDestinationServiceId()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.getContent().toByteArray(),
envelope.getServerTimestamp(),

View File

@@ -50,11 +50,11 @@ public final class MessageMetrics {
public void measureAccountEnvelopeUuidMismatches(final Account account,
final MessageProtos.Envelope envelope) {
if (envelope.hasDestinationUuid()) {
if (envelope.hasDestinationServiceId()) {
try {
measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationUuid()));
measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationServiceId()));
} catch (final IllegalArgumentException ignored) {
logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationUuid());
logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationServiceId());
}
}
}

View File

@@ -92,7 +92,7 @@ public class MessageSender {
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceUuid()))
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment();
}
}

View File

@@ -45,10 +45,10 @@ public class ReceiptSender {
destinationAccount -> {
final Envelope.Builder message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis())
.setSourceUuid(sourceIdentifier.toServiceIdentifierString())
.setSourceDevice((int) sourceDeviceId)
.setDestinationUuid(destinationIdentifier.toServiceIdentifierString())
.setTimestamp(messageId)
.setSourceServiceId(sourceIdentifier.toServiceIdentifierString())
.setSourceDevice(sourceDeviceId)
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setClientTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
.setUrgent(false);

View File

@@ -138,12 +138,13 @@ public class ChangeNumberManager {
final long serverTimestamp = System.currentTimeMillis();
final Envelope envelope = Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type()))
.setTimestamp(serverTimestamp)
.setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setDestinationServiceId(
new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setContent(ByteString.copyFrom(contents.get()))
.setSourceUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice((int) Device.PRIMARY_ID)
.setSourceServiceId(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.build();

View File

@@ -8,10 +8,10 @@ package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.ScoredValue;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.ZAddArgs;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
@@ -20,6 +20,7 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
@@ -38,14 +39,17 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessagesConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Pair;
@@ -57,6 +61,62 @@ import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
/**
* Manages short-term storage of messages in Redis. Messages are frequently delivered to their destination and deleted
* shortly after they reach the server, and this cache acts as a low-latency holding area for new messages, reducing
* load on higher-latency, longer-term storage systems. Redis in particular provides keyspace notifications, which act
* as a form of pub-sub notifications to alert listeners when new messages arrive.
* <p>
* The following structures are used:
* <dl>
* <dt>{@code queueKey}</code></dt>
* <dd>A sorted set of messages in a devices queue. A messages score is its queue-local message ID. See
* <a href="https://redis.io/docs/latest/develop/use/patterns/twitter-clone/#the-sorted-set-data-type">Redis.io: The
* Sorted Set data type</a> for background on scores and this data structure.</dd>
* <dt>{@code queueMetadataKey}</dt>
* <dd>A hash containing message guids and their queue-local message ID. It also contains a {@code counter} key, which is
* incremented to supply the next message ID. This is used to remove a message by GUID from {@code queueKey} by its
* local messageId.</dd>
* <dt>{@code sharedMrmKey}</dt>
* <dd>A hash containing a single multi-recipient message pending delivery. It contains:
* <ul>
* <li>{@code data} - the serialized SealedSenderMultiRecipientMessage data</li>
* <li>fields with each recipient device's “view” into the payload ({@link SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient)}</li>
* </ul>
* Note: this is shared among all of the message's recipients, and it may be located in any Redis shard. As each recipients
* message is delivered, its corresponding view is idempotently removed. When {@code data} is the only remaining
* field, the hash will be deleted.
* </dd>
* <dt>{@code queueLockKey}</dt>
* <dd>Used to indicate that a queue is being modified by the {@link MessagePersister} and that {@code get_items} should
* return an empty list.</dd>
* <dt>{@code queueTotalIndexKey}</dt>
* <dd>A sorted set of all queues in a shard. A queues score is the timestamp of its oldest message, which is used by
* the {@link MessagePersister} to prioritize queues to persist.</dd>
* </dl>
* <p>
* At a high level, the process is:
* <ol>
* <li>Insert: the queue metadata is queried for the next incremented message ID. The message data is inserted into
* the queue at that ID, and the message GUID is inserted in the queue metadata.</li>
* <li>Get: a batch of messages are retrieved from the queue, potentially with an after-message-ID offset.</li>
* <li>Remove: a set of messages are remove by GUID. For each GUID, the message ID is retrieved from the queue metadata,
* and then that single-value range is removed from the queue.</li>
* </ol>
* For multi-recipient messages (sometimes abbreviated “MRM”), there are similar operations on the common data during
* insert, get, and remove. MRM inserts must occur before individual queue inserts, while removal is considered
* best-effort, and uses key expiration as back-stop garbage collection.
* <p>
* For atomicity, many operations are implemented as Lua scripts that are executed on the Redis server using
* {@code EVAL}/{@code EVALSHA}.
*
* @see MessagesCacheInsertScript
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
* @see MessagesCacheGetItemsScript
* @see MessagesCacheRemoveByGuidScript
* @see MessagesCacheRemoveRecipientViewFromMrmDataScript
* @see MessagesCacheRemoveQueueScript
*/
public class MessagesCache extends RedisClusterPubSubAdapter<String, String> implements Managed {
private final FaultTolerantRedisCluster redisCluster;
@@ -69,17 +129,22 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
// messageDeletionExecutorService wrapped into a reactor Scheduler
private final Scheduler messageDeletionScheduler;
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeByGuidScript;
private final ClusterLuaScript getItemsScript;
private final ClusterLuaScript removeQueueScript;
private final ClusterLuaScript getQueuesToPersistScript;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final MessagesCacheInsertScript insertScript;
private final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript;
private final MessagesCacheRemoveByGuidScript removeByGuidScript;
private final MessagesCacheGetItemsScript getItemsScript;
private final MessagesCacheRemoveQueueScript removeQueueScript;
private final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript;
private final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript;
private final ReentrantLock messageListenersLock = new ReentrantLock();
private final Map<String, MessageAvailabilityListener> messageListenersByQueueName = new HashMap<>();
private final Map<MessageAvailabilityListener, String> queueNamesByMessageListener = new IdentityHashMap<>();
private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert"));
private final Timer insertSharedMrmPayloadTimer = Metrics.timer(name(MessagesCache.class, "insertSharedMrmPayload"));
private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get"));
private final Timer getQueuesToPersistTimer = Metrics.timer(name(MessagesCache.class, "getQueuesToPersist"));
private final Timer removeByGuidTimer = Metrics.timer(name(MessagesCache.class, "removeByGuid"));
@@ -95,6 +160,9 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
name(MessagesCache.class, "messageAvailabilityListenerRemovedAfterAdd"));
private final Counter prunedStaleSubscriptionCounter = Metrics.counter(
name(MessagesCache.class, "prunedStaleSubscription"));
private final Counter mrmContentRetrievedCounter = Metrics.counter(name(MessagesCache.class, "mrmViewRetrieved"));
private final Counter sharedMrmDataKeyRemovedCounter = Metrics.counter(
name(MessagesCache.class, "sharedMrmKeyRemoved"));
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
@@ -102,16 +170,49 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::";
private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::";
private static final String MRM_VIEWS_EXPERIMENT_NAME = "mrmViews";
@VisibleForTesting
static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10);
private static final String GET_FLUX_NAME = MetricsUtil.name(MessagesCache.class, "get");
private static final int PAGE_SIZE = 100;
private static final int REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY = 8;
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
public MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService,
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock)
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager)
throws IOException {
this(
redisCluster,
notificationExecutorService,
messageDeliveryScheduler,
messageDeletionExecutorService,
clock,
dynamicConfigurationManager,
new MessagesCacheInsertScript(redisCluster),
new MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(redisCluster),
new MessagesCacheGetItemsScript(redisCluster),
new MessagesCacheRemoveByGuidScript(redisCluster),
new MessagesCacheRemoveQueueScript(redisCluster),
new MessagesCacheGetQueuesToPersistScript(redisCluster),
new MessagesCacheRemoveRecipientViewFromMrmDataScript(redisCluster)
);
}
@VisibleForTesting
MessagesCache(final FaultTolerantRedisCluster redisCluster, final ExecutorService notificationExecutorService,
final Scheduler messageDeliveryScheduler, final ExecutorService messageDeletionExecutorService, final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final MessagesCacheInsertScript insertScript,
final MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript insertMrmScript,
final MessagesCacheGetItemsScript getItemsScript, final MessagesCacheRemoveByGuidScript removeByGuidScript,
final MessagesCacheRemoveQueueScript removeQueueScript,
final MessagesCacheGetQueuesToPersistScript getQueuesToPersistScript,
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript)
throws IOException {
this.redisCluster = redisCluster;
@@ -123,14 +224,15 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
this.messageDeletionExecutorService = messageDeletionExecutorService;
this.messageDeletionScheduler = Schedulers.fromExecutorService(messageDeletionExecutorService, "messageDeletion");
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua",
ScriptOutputType.MULTI);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua",
ScriptOutputType.STATUS);
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua",
ScriptOutputType.MULTI);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.insertScript = insertScript;
this.insertMrmScript = insertMrmScript;
this.removeByGuidScript = removeByGuidScript;
this.getItemsScript = getItemsScript;
this.removeQueueScript = removeQueueScript;
this.getQueuesToPersistScript = getQueuesToPersistScript;
this.removeRecipientViewFromMrmDataScript = removeRecipientViewFromMrmDataScript;
}
@Override
@@ -164,51 +266,51 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice,
final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
return (long) insertTimer.record(() ->
insertScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
List.of(messageWithGuid.toByteArray(),
String.valueOf(message.getServerTimestamp()).getBytes(StandardCharsets.UTF_8),
guid.toString().getBytes(StandardCharsets.UTF_8))));
return insertTimer.record(() -> insertScript.execute(destinationUuid, destinationDevice, messageWithGuid));
}
public CompletableFuture<Optional<MessageProtos.Envelope>> remove(final UUID destinationUuid,
public byte[] insertSharedMultiRecipientMessagePayload(UUID mrmGuid,
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
final byte[] sharedMrmKey = getSharedMrmKey(mrmGuid);
insertSharedMrmPayloadTimer.record(() -> insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage));
return sharedMrmKey;
}
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid,
final byte destinationDevice,
final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid))
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.get(0)));
.thenApply(removed -> removed.isEmpty() ? Optional.empty() : Optional.of(removed.getFirst()));
}
@SuppressWarnings("unchecked")
public CompletableFuture<List<MessageProtos.Envelope>> remove(final UUID destinationUuid,
final byte destinationDevice,
final List<UUID> messageGuids) {
public CompletableFuture<List<RemovedMessage>> remove(final UUID destinationUuid,
final byte destinationDevice, final List<UUID> messageGuids) {
final Timer.Sample sample = Timer.start();
return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.collect(Collectors.toList()))
.thenApplyAsync(result -> {
List<byte[]> serialized = (List<byte[]>) result;
return removeByGuidScript.execute(destinationUuid, destinationDevice, messageGuids)
.thenApplyAsync(serialized -> {
final List<MessageProtos.Envelope> removedMessages = new ArrayList<>(serialized.size());
final List<RemovedMessage> removedMessages = new ArrayList<>(serialized.size());
final List<byte[]> sharedMrmKeysToUpdate = new ArrayList<>();
for (final byte[] bytes : serialized) {
try {
removedMessages.add(MessageProtos.Envelope.parseFrom(bytes));
final MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(bytes);
removedMessages.add(RemovedMessage.fromEnvelope(envelope));
if (envelope.hasSharedMrmKey()) {
sharedMrmKeysToUpdate.add(envelope.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
removeRecipientViewFromMrmData(sharedMrmKeysToUpdate, destinationUuid, destinationDevice);
return removedMessages;
}, messageDeletionExecutorService)
.whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));
}, messageDeletionExecutorService).whenComplete((ignored, throwable) -> sample.stop(removeByGuidTimer));
}
public boolean hasMessages(final UUID destinationUuid, final byte destinationDevice) {
@@ -251,7 +353,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message,
long earliestAllowableTimestamp) {
return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp;
return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp;
}
private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice,
@@ -283,37 +385,101 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
// we want to ensure we dont accidentally block the Lettuce/netty i/o executors
.publishOn(messageDeliveryScheduler)
.map(Pair::first)
.flatMapIterable(queueItems -> {
final List<MessageProtos.Envelope> envelopes = new ArrayList<>(queueItems.size() / 2);
.concatMap(queueItems -> {
final List<Mono<MessageProtos.Envelope>> envelopes = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size() - 1; i += 2) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
envelopes.add(message);
final Mono<MessageProtos.Envelope> messageMono;
if (message.hasSharedMrmKey()) {
maybeRunMrmViewExperiment(message, destinationUuid, destinationDevice);
// mrm views phase 1: messageMono for sharedMrmKey is always Mono.just(), because messages always have content
messageMono = Mono.just(message.toBuilder().clearSharedMrmKey().build());
} else {
messageMono = Mono.just(message);
}
envelopes.add(messageMono);
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
return envelopes;
return Flux.mergeSequential(envelopes);
});
}
private Flux<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice,
/**
* Runs the fetch and compare logic for the MRM view experiment, if it is enabled.
*
* @see DynamicMessagesConfiguration#mrmViewExperimentEnabled()
*/
private void maybeRunMrmViewExperiment(final MessageProtos.Envelope mrmMessage, final UUID destinationUuid,
final byte destinationDevice) {
if (dynamicConfigurationManager.getConfiguration().getMessagesConfiguration()
.mrmViewExperimentEnabled()) {
final Experiment experiment = new Experiment(MRM_VIEWS_EXPERIMENT_NAME);
final byte[] key = mrmMessage.getSharedMrmKey().toByteArray();
final byte[] sharedMrmViewKey = MessagesCache.getSharedMrmViewKey(
new AciServiceIdentifier(destinationUuid), destinationDevice);
final Mono<MessageProtos.Envelope> mrmMessageMono = Mono.from(redisCluster.withBinaryClusterReactive(
conn -> conn.reactive().hmget(key, "data".getBytes(StandardCharsets.UTF_8), sharedMrmViewKey)
.collectList()
.publishOn(messageDeliveryScheduler)
.handle((mrmDataAndView, sink) -> {
try {
assert mrmDataAndView.size() == 2;
final byte[] content = SealedSenderMultiRecipientMessage.messageForRecipient(
mrmDataAndView.getFirst().getValue(),
mrmDataAndView.getLast().getValue());
sink.next(mrmMessage.toBuilder()
.clearSharedMrmKey()
.setContent(ByteString.copyFrom(content))
.build());
mrmContentRetrievedCounter.increment();
} catch (Exception e) {
sink.error(e);
}
})));
experiment.compareFutureResult(mrmMessage.toBuilder().clearSharedMrmKey().build(),
mrmMessageMono.toFuture());
}
}
/**
* Makes a best-effort attempt at asynchronously updating (and removing when empty) the MRM data structure
*/
private void removeRecipientViewFromMrmData(final List<byte[]> sharedMrmKeys, final UUID accountUuid,
final byte deviceId) {
Flux.fromIterable(sharedMrmKeys)
.collectMultimap(SlotHash::getSlot)
.flatMapMany(slotsAndKeys -> Flux.fromIterable(slotsAndKeys.values()))
.flatMap(
keys -> removeRecipientViewFromMrmDataScript.execute(keys, new AciServiceIdentifier(accountUuid), deviceId),
REMOVE_MRM_RECIPIENT_VIEW_CONCURRENCY)
.subscribe(sharedMrmDataKeyRemovedCounter::increment, e -> logger.warn("Error removing recipient view", e));
}
private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid,
final byte destinationDevice,
long messageId) {
return getItemsScript.executeBinaryReactive(
List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getPersistInProgressKey(destinationUuid, destinationDevice)),
List.of(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8),
String.valueOf(messageId).getBytes(StandardCharsets.UTF_8)))
.map(result -> {
return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId)
.map(queueItems -> {
logger.trace("Processing page: {}", messageId);
@SuppressWarnings("unchecked")
List<byte[]> queueItems = (List<byte[]>) result;
if (queueItems.isEmpty()) {
return new Pair<>(Collections.emptyList(), null);
}
@@ -324,7 +490,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
final long lastMessageId = Long.parseLong(
new String(queueItems.get(queueItems.size() - 1), StandardCharsets.UTF_8));
new String(queueItems.getLast(), StandardCharsets.UTF_8));
return new Pair<>(queueItems, lastMessageId);
});
@@ -362,10 +528,35 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
public CompletableFuture<Void> clear(final UUID destinationUuid, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return removeQueueScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, deviceId),
getMessageQueueMetadataKey(destinationUuid, deviceId),
getQueueIndexKey(destinationUuid, deviceId)),
Collections.emptyList())
return removeQueueScript.execute(destinationUuid, deviceId, Collections.emptyList())
.publishOn(messageDeletionScheduler)
.expand(messagesToProcess -> {
if (messagesToProcess.isEmpty()) {
return Mono.empty();
}
final List<byte[]> mrmKeys = new ArrayList<>(messagesToProcess.size());
final List<String> processedMessages = new ArrayList<>(messagesToProcess.size());
for (byte[] serialized : messagesToProcess) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(serialized);
processedMessages.add(message.getServerGuid());
if (message.hasSharedMrmKey()) {
mrmKeys.add(message.getSharedMrmKey().toByteArray());
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
removeRecipientViewFromMrmData(mrmKeys, destinationUuid, deviceId);
return removeQueueScript.execute(destinationUuid, deviceId, processedMessages);
})
.then()
.toFuture()
.thenRun(() -> sample.stop(clearQueueTimer));
}
@@ -375,11 +566,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
List<String> getQueuesToPersist(final int slot, final Instant maxTime, final int limit) {
//noinspection unchecked
return getQueuesToPersistTimer.record(() -> (List<String>) getQueuesToPersistScript.execute(
List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)),
List.of(String.valueOf(maxTime.toEpochMilli()),
String.valueOf(limit))));
return getQueuesToPersistTimer.record(() -> getQueuesToPersistScript.execute(slot, maxTime, limit));
}
void addQueueToPersist(final UUID accountUuid, final byte deviceId) {
@@ -538,29 +725,36 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return channel.substring(startOfHashTag + 1, endOfHashTag);
}
@VisibleForTesting
static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) {
static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) {
static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) {
return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId));
}
private static byte[] getQueueIndexKey(final int slot) {
static byte[] getQueueIndexKey(final int slot) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) {
static byte[] getSharedMrmKey(final UUID mrmGuid) {
return ("mrm::{" + mrmGuid.toString() + "}").getBytes(StandardCharsets.UTF_8);
}
static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getUnlinkInProgressKey(final UUID accountUuid) {
return ("user_account_unlinking::{" + accountUuid + "}").getBytes(StandardCharsets.UTF_8);
static byte[] getSharedMrmViewKey(final AciServiceIdentifier serviceIdentifier, final byte deviceId) {
final ByteBuffer keyBb = ByteBuffer.allocate(18);
keyBb.put(serviceIdentifier.toFixedWidthByteArray());
keyBb.put(deviceId);
assert !keyBb.hasRemaining();
return keyBb.array();
}
static UUID getAccountUuidFromQueueName(final String queueName) {

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Retrieves a list of messages and their corresponding queue-local IDs for the device. To support streaming processing,
* the last queue-local message ID from a previous call may be used as the {@code afterMessageId}.
*/
class MessagesCacheGetItemsScript {
private final ClusterLuaScript getItemsScript;
MessagesCacheGetItemsScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.OBJECT);
}
Mono<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
int limit, long afterMessageId) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getPersistInProgressKey(destinationUuid, destinationDevice) // queueLockKey
);
final List<byte[]> args = List.of(
String.valueOf(limit).getBytes(StandardCharsets.UTF_8), // limit
String.valueOf(afterMessageId).getBytes(StandardCharsets.UTF_8) // afterMessageId
);
//noinspection unchecked
return getItemsScript.executeBinaryReactive(keys, args)
.map(result -> (List<byte[]>) result)
.next();
}
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.List;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Returns a list of queues that may be persisted. They will be sorted from oldest to more recent, limited by the
* {@code maxTime} argument.
*
* @see MessagePersister
*/
class MessagesCacheGetQueuesToPersistScript {
private final ClusterLuaScript getQueuesToPersistScript;
MessagesCacheGetQueuesToPersistScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua",
ScriptOutputType.MULTI);
}
List<String> execute(final int slot, final Instant maxTime, final int limit) {
final List<String> keys = List.of(
new String(MessagesCache.getQueueIndexKey(slot), StandardCharsets.UTF_8) // queueTotalIndexKey
);
final List<String> args = List.of(
String.valueOf(maxTime.toEpochMilli()), // maxTime
String.valueOf(limit) // limit
);
//noinspection unchecked
return (List<String>) getQueuesToPersistScript.execute(keys, args);
}
}

View File

@@ -0,0 +1,48 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Inserts an envelope into the message queue for a destination device.
*/
class MessagesCacheInsertScript {
private final ClusterLuaScript insertScript;
MessagesCacheInsertScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
}
long execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
assert envelope.hasServerGuid();
assert envelope.hasServerTimestamp();
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = new ArrayList<>(Arrays.asList(
envelope.toByteArray(), // message
String.valueOf(envelope.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), // currentTime
envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid
));
return (long) insertScript.executeBinary(keys, args);
}
}

View File

@@ -0,0 +1,53 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as
* fields in the hash.
*
* @see SealedSenderMultiRecipientMessage#serializedRecipientView(SealedSenderMultiRecipientMessage.Recipient)
*/
class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
private final ClusterLuaScript script;
MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript(FaultTolerantRedisCluster redisCluster)
throws IOException {
this.script = ClusterLuaScript.fromResource(redisCluster, "lua/insert_shared_multirecipient_message_data.lua",
ScriptOutputType.INTEGER);
}
void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) {
final List<byte[]> keys = List.of(
sharedMrmKey // sharedMrmKey
);
// Pre-allocate capacity for the most fields we expect -- 6 devices per recipient, plus the data field.
final List<byte[]> args = new ArrayList<>(message.getRecipients().size() * 6 + 1);
args.add(message.serialized());
message.getRecipients().forEach((serviceId, recipient) -> {
for (byte device : recipient.getDevices()) {
final byte[] key = new byte[18];
System.arraycopy(serviceId.toServiceIdFixedWidthBinary(), 0, key, 0, 17);
key[17] = device;
args.add(key);
args.add(message.serializedRecipientView(recipient));
}
});
script.executeBinary(keys, args);
}
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
/**
* Removes a list of message GUIDs from the queue of a destination device.
*/
class MessagesCacheRemoveByGuidScript {
private final ClusterLuaScript removeByGuidScript;
MessagesCacheRemoveByGuidScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua",
ScriptOutputType.OBJECT);
}
CompletableFuture<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
final List<UUID> messageGuids) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.toList();
//noinspection unchecked
return removeByGuidScript.executeBinaryAsync(keys, args)
.thenApply(result -> (List<byte[]>) result);
}
}

View File

@@ -0,0 +1,58 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Removes a device's queue from the cache. For a non-empty queue, this script must be executed multiple times.
* <ol>
* <li>The first call will return a list of messages to check for {@code sharedMrmKeys}. If a {@code sharedMrmKey} is present, {@link MessagesCacheRemoveRecipientViewFromMrmDataScript} must be called.</li>
* <li>Once theses messages have been processed, this script should be called again, confirming that the messages have been processed.</li>
* <li>This should be repeated until the script returns an empty list, as the script only returns a page ({@value PAGE_SIZE}) of messages at a time.</li>
* </ol>
*/
class MessagesCacheRemoveQueueScript {
private static final int PAGE_SIZE = 100;
private final ClusterLuaScript removeQueueScript;
MessagesCacheRemoveQueueScript(FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua",
ScriptOutputType.MULTI);
}
Mono<List<byte[]>> execute(final UUID destinationUuid, final byte destinationDevice,
final List<String> processedMessageGuids) {
final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey
);
final List<byte[]> args = new ArrayList<>();
args.addFirst(String.valueOf(PAGE_SIZE).getBytes(StandardCharsets.UTF_8)); // limit
args.addAll(processedMessageGuids.stream().map(guid -> guid.getBytes(StandardCharsets.UTF_8))
.toList()); // processedMessageGuids
//noinspection unchecked
return removeQueueScript.executeBinaryReactive(keys, args)
.map(result -> (List<byte[]>) result)
.next();
}
}

View File

@@ -0,0 +1,44 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import reactor.core.publisher.Mono;
/**
* Removes the given destination device from the given {@code sharedMrmKeys}. If there are no devices remaining in the
* hash as a result, the shared payload is deleted.
* <p>
* NOTE: Callers are responsible for ensuring that all keys are in the same slot.
*/
class MessagesCacheRemoveRecipientViewFromMrmDataScript {
private final ClusterLuaScript removeRecipientViewFromMrmDataScript;
MessagesCacheRemoveRecipientViewFromMrmDataScript(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.removeRecipientViewFromMrmDataScript = ClusterLuaScript.fromResource(redisCluster,
"lua/remove_recipient_view_from_mrm_data.lua", ScriptOutputType.INTEGER);
}
Mono<Long> execute(final Collection<byte[]> keysCollection, final AciServiceIdentifier serviceIdentifier,
final byte deviceId) {
final List<byte[]> keys = keysCollection instanceof List<byte[]>
? (List<byte[]>) keysCollection
: new ArrayList<>(keysCollection);
return removeRecipientViewFromMrmDataScript.executeBinaryReactive(keys,
List.of(MessagesCache.getSharedMrmViewKey(serviceIdentifier, deviceId)))
.map(o -> (long) o)
.next();
}
}

View File

@@ -19,8 +19,10 @@ import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.Pair;
@@ -62,8 +64,8 @@ public class MessagesManager {
messagesCache.insert(messageGuid, destinationUuid, destinationDevice, message);
if (message.hasSourceUuid() && !destinationUuid.toString().equals(message.getSourceUuid())) {
reportMessageManager.store(message.getSourceUuid(), messageGuid);
if (message.hasSourceServiceId() && !destinationUuid.toString().equals(message.getSourceServiceId())) {
reportMessageManager.store(message.getSourceServiceId(), messageGuid);
}
}
@@ -137,7 +139,7 @@ public class MessagesManager {
return messagesCache.clear(destinationUuid, deviceId);
}
public CompletableFuture<Optional<Envelope>> delete(UUID destinationUuid, Device destinationDevice, UUID guid,
public CompletableFuture<Optional<RemovedMessage>> delete(UUID destinationUuid, Device destinationDevice, UUID guid,
@Nullable Long serverTimestamp) {
return messagesCache.remove(destinationUuid, destinationDevice.getId(), guid)
.thenComposeAsync(removed -> {
@@ -146,12 +148,16 @@ public class MessagesManager {
return CompletableFuture.completedFuture(removed);
}
final CompletableFuture<Optional<MessageProtos.Envelope>> maybeDeletedEnvelope;
if (serverTimestamp == null) {
return messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, destinationDevice, guid);
maybeDeletedEnvelope = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid,
destinationDevice, guid);
} else {
return messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid, serverTimestamp);
maybeDeletedEnvelope = messagesDynamoDb.deleteMessage(destinationUuid, destinationDevice, guid,
serverTimestamp);
}
return maybeDeletedEnvelope.thenApply(maybeEnvelope -> maybeEnvelope.map(RemovedMessage::fromEnvelope));
}, messageDeletionExecutor);
}
@@ -194,4 +200,14 @@ public class MessagesManager {
messagesCache.removeMessageAvailabilityListener(listener);
}
/**
* Inserts the shared multi-recipient message payload to storage.
*
* @return a key where the shared data is stored
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
*/
public byte[] insertSharedMultiRecipientMessagePayload(
SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return messagesCache.insertSharedMultiRecipientMessagePayload(UUID.randomUUID(), sealedSenderMultiRecipientMessage);
}
}

View File

@@ -0,0 +1,30 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.util.Optional;
import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
public record RemovedMessage(Optional<ServiceIdentifier> sourceServiceId, ServiceIdentifier destinationServiceId,
@VisibleForTesting UUID serverGuid, long serverTimestamp, long clientTimestamp,
MessageProtos.Envelope.Type envelopeType) {
public static RemovedMessage fromEnvelope(MessageProtos.Envelope envelope) {
return new RemovedMessage(
envelope.hasSourceServiceId()
? Optional.of(ServiceIdentifier.valueOf(envelope.getSourceServiceId()))
: Optional.empty(),
ServiceIdentifier.valueOf(envelope.getDestinationServiceId()),
UUID.fromString(envelope.getServerGuid()),
envelope.getServerTimestamp(),
envelope.getClientTimestamp(),
envelope.getType()
);
}
}

View File

@@ -294,16 +294,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
}
private void sendDeliveryReceiptFor(Envelope message) {
if (!message.hasSourceUuid()) {
if (!message.hasSourceServiceId()) {
return;
}
try {
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationUuid()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceUuid()),
message.getTimestamp());
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
message.getClientTimestamp());
} catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceUuid());
logger.error("Could not parse UUID: {}", message.getSourceServiceId());
} catch (Exception e) {
logger.warn("Failed to send receipt", e);
}

View File

@@ -205,7 +205,7 @@ record CommandDependencies(
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
recurringJobExecutor, keyspaceNotificationDispatchExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC());
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient,
configuration.getDynamoDbTables().getReportMessage().getTableName(),