Make Envelope the main unit of currency when working with stored messages

This commit is contained in:
Jon Chambers
2022-07-27 15:43:39 -04:00
committed by Jon Chambers
parent 3e0919106d
commit 3636626e09
9 changed files with 245 additions and 278 deletions

View File

@@ -21,7 +21,6 @@ import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -414,11 +413,19 @@ public class MessageController {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice()));
}
final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
userAgent,
false);
final OutgoingMessageEntityList outgoingMessages;
{
final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager.getMessagesForDevice(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
userAgent,
false);
outgoingMessages = new OutgoingMessageEntityList(messagesAndHasMore.first().stream()
.map(OutgoingMessageEntity::fromEnvelope)
.collect(Collectors.toList()),
messagesAndHasMore.second());
}
{
String platform;
@@ -450,24 +457,22 @@ public class MessageController {
@DELETE
@Path("/uuid/{uuid}")
public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) {
try {
Optional<OutgoingMessageEntity> message = messagesManager.delete(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
uuid,
null);
messagesManager.delete(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
uuid,
null).ifPresent(deletedMessage -> {
if (message.isPresent()) {
WebSocketConnection.recordMessageDeliveryDuration(message.get().timestamp(), auth.getAuthenticatedDevice());
if (!Util.isEmpty(message.get().source())
&& message.get().type() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) {
receiptSender.sendReceipt(auth, message.get().sourceUuid(), message.get().timestamp());
WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getTimestamp(), auth.getAuthenticatedDevice());
if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) {
try {
receiptSender.sendReceipt(auth, UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
}
}
} catch (NoSuchUserException e) {
logger.warn("Sending delivery receipt", e);
}
});
}
@Timed

View File

@@ -40,7 +40,6 @@ import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
@@ -148,13 +147,13 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
guid.toString().getBytes(StandardCharsets.UTF_8))));
}
public Optional<OutgoingMessageEntity> remove(final UUID destinationUuid, final long destinationDevice,
public Optional<MessageProtos.Envelope> remove(final UUID destinationUuid, final long destinationDevice,
final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid)).stream().findFirst();
}
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> remove(final UUID destinationUuid, final long destinationDevice,
public List<MessageProtos.Envelope> remove(final UUID destinationUuid, final long destinationDevice,
final List<UUID> messageGuids) {
final List<byte[]> serialized = (List<byte[]>) Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG,
REMOVE_METHOD_UUID).record(() ->
@@ -164,11 +163,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.collect(Collectors.toList())));
final List<OutgoingMessageEntity> removedMessages = new ArrayList<>(serialized.size());
final List<MessageProtos.Envelope> removedMessages = new ArrayList<>(serialized.size());
for (final byte[] bytes : serialized) {
try {
removedMessages.add(constructEntityFromEnvelope(MessageProtos.Envelope.parseFrom(bytes)));
removedMessages.add(MessageProtos.Envelope.parseFrom(bytes));
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
@@ -183,7 +182,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(final UUID destinationUuid, final long destinationDevice, final int limit) {
public List<MessageProtos.Envelope> get(final UUID destinationUuid, final long destinationDevice, final int limit) {
return getMessagesTimer.record(() -> {
final List<byte[]> queueItems = (List<byte[]>) getItemsScript.executeBinary(
List.of(getMessageQueueKey(destinationUuid, destinationDevice),
@@ -193,7 +192,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
final long earliestAllowableEphemeralTimestamp =
System.currentTimeMillis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();
final List<OutgoingMessageEntity> messageEntities;
final List<MessageProtos.Envelope> messageEntities;
final List<UUID> staleEphemeralMessageGuids = new ArrayList<>();
if (queueItems.size() % 2 == 0) {
@@ -207,9 +206,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
continue;
}
final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8));
messageEntities.add(constructEntityFromEnvelope(message));
messageEntities.add(message);
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
@@ -379,21 +376,6 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
}
@VisibleForTesting
static OutgoingMessageEntity constructEntityFromEnvelope(MessageProtos.Envelope envelope) {
return new OutgoingMessageEntity(
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}
@VisibleForTesting
static String getQueueName(final UUID accountUuid, final long deviceId) {
return accountUuid + "::" + deviceId;

View File

@@ -112,7 +112,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems));
}
public List<OutgoingMessageEntity> load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) {
public List<MessageProtos.Envelope> load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) {
return loadTimer.record(() -> {
final int numberOfMessagesToFetch = Math.min(requestedNumberOfMessagesToFetch, RESULT_SET_CHUNK_SIZE);
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@@ -128,9 +128,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId)))
.limit(numberOfMessagesToFetch)
.build();
List<OutgoingMessageEntity> messageEntities = new ArrayList<>(numberOfMessagesToFetch);
List<MessageProtos.Envelope> messageEntities = new ArrayList<>(numberOfMessagesToFetch);
for (Map<String, AttributeValue> message : db().queryPaginator(queryRequest).items()) {
messageEntities.add(convertItemToOutgoingMessageEntity(message));
messageEntities.add(convertItemToEnvelope(message));
if (messageEntities.size() == numberOfMessagesToFetch) {
// queryPaginator() uses limit() as the page size, not as an absolute limit
// …but a page might be smaller than limit, because a page is capped at 1 MB
@@ -141,7 +141,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
});
}
public Optional<OutgoingMessageEntity> deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid,
public Optional<MessageProtos.Envelope> deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid,
final UUID messageUuid) {
return deleteByGuid.record(() -> {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@@ -162,7 +162,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
});
}
public Optional<OutgoingMessageEntity> deleteMessage(final UUID destinationAccountUuid,
public Optional<MessageProtos.Envelope> deleteMessage(final UUID destinationAccountUuid,
final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) {
return deleteByKey.record(() -> {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@@ -173,7 +173,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
.returnValues(ReturnValue.ALL_OLD);
final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build());
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
return Optional.of(convertItemToOutgoingMessageEntity(deleteItemResponse.attributes()));
return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes()));
}
return Optional.empty();
@@ -181,8 +181,8 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
}
@Nonnull
private Optional<OutgoingMessageEntity> deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) {
Optional<OutgoingMessageEntity> result = Optional.empty();
private Optional<MessageProtos.Envelope> deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) {
Optional<MessageProtos.Envelope> result = Optional.empty();
for (Map<String, AttributeValue> item : db().queryPaginator(queryRequest).items()) {
final byte[] rangeKeyValue = item.get(KEY_SORT).b().asByteArray();
DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder()
@@ -193,7 +193,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
}
final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build());
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
result = Optional.of(convertItemToOutgoingMessageEntity(deleteItemResponse.attributes()));
result = Optional.of(convertItemToEnvelope(deleteItemResponse.attributes()));
}
}
return result;
@@ -233,19 +233,20 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
});
}
private OutgoingMessageEntity convertItemToOutgoingMessageEntity(Map<String, AttributeValue> message) {
final SortKey sortKey = convertSortKey(message.get(KEY_SORT).b().asByteArray());
final UUID messageUuid = convertLocalIndexMessageUuidSortKey(message.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray());
final int type = AttributeValues.getInt(message, KEY_TYPE, 0);
final long timestamp = AttributeValues.getLong(message, KEY_TIMESTAMP, 0L);
final String source = AttributeValues.getString(message, KEY_SOURCE, null);
final UUID sourceUuid = AttributeValues.getUUID(message, KEY_SOURCE_UUID, null);
final int sourceDevice = AttributeValues.getInt(message, KEY_SOURCE_DEVICE, 0);
final UUID destinationUuid = AttributeValues.getUUID(message, KEY_DESTINATION_UUID, null);
final byte[] content = AttributeValues.getByteArray(message, KEY_CONTENT, null);
final UUID updatedPni = AttributeValues.getUUID(message, KEY_UPDATED_PNI, null);
private MessageProtos.Envelope convertItemToEnvelope(final Map<String, AttributeValue> item) {
final SortKey sortKey = convertSortKey(item.get(KEY_SORT).b().asByteArray());
final UUID messageUuid = convertLocalIndexMessageUuidSortKey(item.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray());
final int type = AttributeValues.getInt(item, KEY_TYPE, 0);
final long timestamp = AttributeValues.getLong(item, KEY_TIMESTAMP, 0L);
final String source = AttributeValues.getString(item, KEY_SOURCE, null);
final UUID sourceUuid = AttributeValues.getUUID(item, KEY_SOURCE_UUID, null);
final int sourceDevice = AttributeValues.getInt(item, KEY_SOURCE_DEVICE, 0);
final UUID destinationUuid = AttributeValues.getUUID(item, KEY_DESTINATION_UUID, null);
final byte[] content = AttributeValues.getByteArray(item, KEY_CONTENT, null);
final UUID updatedPni = AttributeValues.getUUID(item, KEY_UPDATED_PNI, null);
return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid,
updatedPni, content, sortKey.getServerTimestamp());
updatedPni, content, sortKey.getServerTimestamp()).toEnvelope();
}
private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) {

View File

@@ -15,11 +15,10 @@ import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
public class MessagesManager {
@@ -61,10 +60,10 @@ public class MessagesManager {
return messagesCache.hasMessages(destinationUuid, destinationDevice);
}
public OutgoingMessageEntityList getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) {
public Pair<List<Envelope>, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) {
RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent));
List<OutgoingMessageEntity> messageList = new ArrayList<>();
List<Envelope> messageList = new ArrayList<>();
if (!cachedMessagesOnly) {
messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE));
@@ -74,7 +73,7 @@ public class MessagesManager {
messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE - messageList.size()));
}
return new OutgoingMessageEntityList(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE);
return new Pair<>(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE);
}
public void clear(UUID destinationUuid) {
@@ -87,8 +86,8 @@ public class MessagesManager {
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId);
}
public Optional<OutgoingMessageEntity> delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) {
Optional<OutgoingMessageEntity> removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid);
public Optional<Envelope> delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) {
Optional<Envelope> removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid);
if (removed.isEmpty()) {
if (serverTimestamp == null) {

View File

@@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@@ -305,22 +306,25 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) {
try {
final OutgoingMessageEntityList messages = messagesManager
final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager
.getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly);
final CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.messages().size()];
final List<Envelope> messages = messagesAndHasMore.first();
final boolean hasMore = messagesAndHasMore.second();
for (int i = 0; i < messages.messages().size(); i++) {
final OutgoingMessageEntity message = messages.messages().get(i);
final Envelope envelope = message.toEnvelope();
final CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.size()];
for (int i = 0; i < messages.size(); i++) {
final Envelope envelope = messages.get(i);
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), message.guid(), message.serverTimestamp());
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp());
discardedMessagesMeter.mark();
sendFutures[i] = CompletableFuture.completedFuture(null);
} else {
sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(message.guid(), message.serverTimestamp())));
sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(messageGuid, envelope.getServerTimestamp())));
}
}
@@ -329,7 +333,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
.orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS)
.whenComplete((v, cause) -> {
if (cause == null) {
if (messages.more()) {
if (hasMore) {
sendNextMessagePage(cachedMessagesOnly, queueClearedFuture);
} else {
queueClearedFuture.complete(null);