Add a cluster-backed message cache.

This commit is contained in:
Jon Chambers
2020-07-09 09:34:20 -04:00
committed by Jon Chambers
parent 639898ec07
commit 6fc1b4c6c0
15 changed files with 690 additions and 59 deletions

View File

@@ -50,7 +50,6 @@ import io.micrometer.datadog.DatadogMeterRegistry;
import io.micrometer.wavefront.WavefrontConfig;
import io.micrometer.wavefront.WavefrontMeterRegistry;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.coursera.metrics.datadog.DatadogReporter;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.jdbi.v3.core.Jdbi;
import org.signal.zkgroup.ServerSecretParams;
@@ -137,6 +136,7 @@ import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.storage.RedisClusterMessagesCache;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
@@ -300,7 +300,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
RedisClusterClient cacheClusterClient = RedisClusterClient.create(config.getCacheClusterConfiguration().getUrls().stream().map(RedisURI::create).collect(Collectors.toList()));
cacheClusterClient.setDefaultTimeout(config.getCacheClusterConfiguration().getTimeout());
FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", config.getCacheClusterConfiguration().getUrls(), config.getCacheClusterConfiguration().getTimeout(), config.getCacheClusterConfiguration().getCircuitBreakerConfiguration());
FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", config.getCacheClusterConfiguration().getUrls(), config.getCacheClusterConfiguration().getTimeout(), config.getCacheClusterConfiguration().getCircuitBreakerConfiguration());
FaultTolerantRedisCluster messagesCacheCluster = new FaultTolerantRedisCluster("messages_cluster", config.getMessageCacheConfiguration().getRedisClusterConfiguration().getUrls(), config.getMessageCacheConfiguration().getRedisClusterConfiguration().getTimeout(), config.getMessageCacheConfiguration().getRedisClusterConfiguration().getCircuitBreakerConfiguration());
DirectoryManager directory = new DirectoryManager(directoryClient);
DirectoryQueue directoryQueue = new DirectoryQueue(config.getDirectoryConfiguration().getSqsConfiguration());
@@ -309,7 +310,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountsManager accountsManager = new AccountsManager(accounts, directory, cacheCluster);
UsernamesManager usernamesManager = new UsernamesManager(usernames, reservedUsernames, cacheCluster);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesClient, messages, accountsManager, config.getMessageCacheConfiguration().getPersistDelayMinutes());
RedisClusterMessagesCache clusterMessagesCache = new RedisClusterMessagesCache(messagesCacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesClient, messages, accountsManager, config.getMessageCacheConfiguration().getPersistDelayMinutes(), clusterMessagesCache);
MessagesManager messagesManager = new MessagesManager(messages, messagesCache);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
DeadLetterHandler deadLetterHandler = new DeadLetterHandler(messagesManager);

View File

@@ -12,6 +12,11 @@ public class MessageCacheConfiguration {
@Valid
private RedisConfiguration redis;
@JsonProperty
@NotNull
@Valid
private RedisClusterConfiguration cluster;
@JsonProperty
private int persistDelayMinutes = 10;
@@ -19,6 +24,10 @@ public class MessageCacheConfiguration {
return redis;
}
public RedisClusterConfiguration getRedisClusterConfiguration() {
return cluster;
}
public int getPersistDelayMinutes() {
return persistDelayMinutes;
}

View File

@@ -3,6 +3,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import java.util.Objects;
import java.util.UUID;
public class OutgoingMessageEntity {
@@ -114,4 +116,30 @@ public class OutgoingMessageEntity {
return serverTimestamp;
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final OutgoingMessageEntity that = (OutgoingMessageEntity)o;
return id == that.id &&
cached == that.cached &&
type == that.type &&
timestamp == that.timestamp &&
sourceDevice == that.sourceDevice &&
serverTimestamp == that.serverTimestamp &&
Objects.equals(guid, that.guid) &&
Objects.equals(relay, that.relay) &&
Objects.equals(source, that.source) &&
Objects.equals(sourceUuid, that.sourceUuid) &&
Arrays.equals(message, that.message) &&
Arrays.equals(content, that.content);
}
@Override
public int hashCode() {
int result = Objects.hash(id, cached, guid, type, relay, timestamp, source, sourceUuid, sourceDevice, serverTimestamp);
result = 31 * result + Arrays.hashCode(message);
result = 31 * result + Arrays.hashCode(content);
return result;
}
}

View File

@@ -5,10 +5,12 @@ import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.redis.LuaScript;
@@ -17,6 +19,9 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Tuple;
import redis.clients.util.SafeEncoder;
import java.io.IOException;
import java.util.Arrays;
@@ -27,18 +32,17 @@ import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Tuple;
import redis.clients.util.SafeEncoder;
public class MessagesCache implements Managed {
public class MessagesCache implements Managed, UserMessagesCache {
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer insertTimer = metricRegistry.timer(name(MessagesCache.class, "insert" ));
@@ -54,50 +58,85 @@ public class MessagesCache implements Managed {
private final AccountsManager accountsManager;
private final int delayMinutes;
private InsertOperation insertOperation;
private RemoveOperation removeOperation;
private GetOperation getOperation;
private final InsertOperation insertOperation;
private final RemoveOperation removeOperation;
private final GetOperation getOperation;
private PubSubManager pubSubManager;
private PushSender pushSender;
private MessagePersister messagePersister;
public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes) {
this.jedisPool = jedisPool;
this.database = database;
this.accountsManager = accountsManager;
this.delayMinutes = delayMinutes;
private final RedisClusterMessagesCache clusterMessagesCache;
private final ExecutorService experimentExecutor = new ThreadPoolExecutor(8, 8, 0, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(1_000));
private final Experiment insertExperiment = new Experiment("MessagesCache", "insert");
private final Experiment removeByIdExperiment = new Experiment("MessagesCache", "removeById");
private final Experiment removeBySenderExperiment = new Experiment("MessagesCache", "removeBySender");
private final Experiment removeByUuidExperiment = new Experiment("MessagesCache", "removeByUuid");
private final Experiment getMessagesExperiment = new Experiment("MessagesCache", "getMessages");
public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes, RedisClusterMessagesCache clusterMessagesCache) throws IOException {
this.jedisPool = jedisPool;
this.database = database;
this.accountsManager = accountsManager;
this.delayMinutes = delayMinutes;
this.insertOperation = new InsertOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
this.clusterMessagesCache = clusterMessagesCache;
}
public void insert(UUID guid, String destination, long destinationDevice, Envelope message) {
message = message.toBuilder().setServerGuid(guid.toString()).build();
@Override
public long insert(UUID guid, String destination, long destinationDevice, Envelope message) {
final Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
Timer.Context timer = insertTimer.time();
try {
insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), message);
final long messageId = insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), messageWithGuid);
insertExperiment.compareSupplierResultAsync(messageId, () -> clusterMessagesCache.insert(guid, destination, destinationDevice, message, messageId), experimentExecutor);
return messageId;
} finally {
timer.stop();
}
}
public void remove(String destination, long destinationDevice, long id) {
@Override
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, long id) {
OutgoingMessageEntity removedMessageEntity = null;
try (Jedis jedis = jedisPool.getWriteResource();
Timer.Context ignored = removeByIdTimer.time())
{
removeOperation.remove(jedis, destination, destinationDevice, id);
byte[] serialized = removeOperation.remove(jedis, destination, destinationDevice, id);
if (serialized != null) {
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(id, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
final Optional<OutgoingMessageEntity> maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
removeByIdExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, id), experimentExecutor);
return maybeRemovedMessage;
}
@Override
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, String sender, long timestamp) {
OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByNameTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp);
if (serialized != null) {
Envelope envelope = Envelope.parseFrom(serialized);
return Optional.of(constructEntityFromEnvelope(0, envelope));
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@@ -105,18 +144,23 @@ public class MessagesCache implements Managed {
timer.stop();
}
return Optional.empty();
final Optional<OutgoingMessageEntity> maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
removeBySenderExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, sender, timestamp), experimentExecutor);
return maybeRemovedMessage;
}
@Override
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, UUID guid) {
OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByGuidTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, guid);
if (serialized != null) {
Envelope envelope = Envelope.parseFrom(serialized);
return Optional.of(constructEntityFromEnvelope(0, envelope));
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@@ -124,9 +168,14 @@ public class MessagesCache implements Managed {
timer.stop();
}
return Optional.empty();
final Optional<OutgoingMessageEntity> maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
removeByUuidExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, guid), experimentExecutor);
return maybeRemovedMessage;
}
@Override
public List<OutgoingMessageEntity> get(String destination, long destinationDevice, int limit) {
Timer.Context timer = getTimer.time();
@@ -139,18 +188,21 @@ public class MessagesCache implements Managed {
try {
long id = item.second().longValue();
Envelope message = Envelope.parseFrom(item.first());
results.add(constructEntityFromEnvelope(id, message));
results.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
getMessagesExperiment.compareSupplierResultAsync(results, () -> clusterMessagesCache.get(destination, destinationDevice, limit), experimentExecutor);
return results;
} finally {
timer.stop();
}
}
@Override
public void clear(String destination) {
Timer.Context timer = clearAccountTimer.time();
@@ -163,6 +215,7 @@ public class MessagesCache implements Managed {
}
}
@Override
public void clear(String destination, long deviceId) {
Timer.Context timer = clearDeviceTimer.time();
@@ -180,11 +233,7 @@ public class MessagesCache implements Managed {
@Override
public void start() throws Exception {
this.insertOperation = new InsertOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES);
this.messagePersister.start();
}
@@ -192,20 +241,8 @@ public class MessagesCache implements Managed {
public void stop() throws Exception {
messagePersister.shutdown();
logger.info("Message persister shut down...");
}
private OutgoingMessageEntity constructEntityFromEnvelope(long id, Envelope envelope) {
return new OutgoingMessageEntity(id, true,
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getRelay(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
this.experimentExecutor.shutdown();
}
private static class Key {
@@ -271,14 +308,14 @@ public class MessagesCache implements Managed {
this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua");
}
public void insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) {
public long insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) {
Key key = new Key(destination, destinationDevice);
String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes(), guid.toString().getBytes());
insert.execute(keys, args);
return (long)insert.execute(keys, args);
}
}
@@ -296,13 +333,13 @@ public class MessagesCache implements Managed {
this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" );
}
public void remove(Jedis jedis, String destination, long destinationDevice, long id) {
public byte[] remove(Jedis jedis, String destination, long destinationDevice, long id) {
Key key = new Key(destination, destinationDevice);
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Collections.singletonList(String.valueOf(id).getBytes());
this.removeById.execute(jedis, keys, args);
return (byte[])this.removeById.execute(jedis, keys, args);
}
public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) {

View File

@@ -0,0 +1,206 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.InvalidProtocolBufferException;
import io.lettuce.core.ScriptOutputType;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
public class RedisClusterMessagesCache implements UserMessagesCache {
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeByIdScript;
private final ClusterLuaScript removeBySenderScript;
private final ClusterLuaScript removeByGuidScript;
private final ClusterLuaScript getItemsScript;
private final ClusterLuaScript removeQueueScript;
private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert");
private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove");
private static final String GET_TIMER_NAME = name(RedisClusterMessagesCache.class, "get");
private static final String CLEAR_TIMER_NAME = name(RedisClusterMessagesCache.class, "clear");
private static final String REMOVE_METHOD_TAG = "method";
private static final String REMOVE_METHOD_ID = "id";
private static final String REMOVE_METHOD_SENDER = "sender";
private static final String REMOVE_METHOD_UUID = "uuid";
private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagesCache.class);
public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
}
@Override
public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
return (long)Metrics.timer(INSERT_TIMER_NAME).record(() ->
insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(messageWithGuid.toByteArray(),
String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8),
sender.getBytes(StandardCharsets.UTF_8),
guid.toString().getBytes(StandardCharsets.UTF_8))));
}
public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message, final long messageId) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
return (long)Metrics.timer(INSERT_TIMER_NAME).record(() ->
insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(messageWithGuid.toByteArray(),
String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8),
sender.getBytes(StandardCharsets.UTF_8),
guid.toString().getBytes(StandardCharsets.UTF_8),
String.valueOf(messageId).getBytes(StandardCharsets.UTF_8))));
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final long destinationDevice, final long id) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_ID).record(() ->
removeByIdScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(String.valueOf(id).getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(id, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
return Optional.empty();
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final long destinationDevice, final String sender, final long timestamp) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_SENDER).record(() ->
removeBySenderScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of((sender + "::" + timestamp).getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
return Optional.empty();
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final long destinationDevice, final UUID guid) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() ->
removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(guid.toString().getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
return Optional.empty();
}
@Override
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(String destination, long destinationDevice, int limit) {
return Metrics.timer(GET_TIMER_NAME).record(() -> {
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getPersistInProgressKey(destination, destinationDevice)),
List.of(String.valueOf(limit).getBytes()));
final List<OutgoingMessageEntity> messageEntities;
if (queueItems.size() % 2 == 0) {
messageEntities = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size() - 1; i += 2) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8));
messageEntities.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
} else {
logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
messageEntities = Collections.emptyList();
}
return messageEntities;
});
}
@Override
public void clear(final String destination) {
for (int i = 1; i < 256; i++) {
clear(destination, i);
}
}
@Override
public void clear(final String destination, final long deviceId) {
Metrics.timer(CLEAR_TIMER_NAME).record(() ->
removeQueueScript.executeBinary(List.of(getMessageQueueKey(destination, deviceId),
getMessageQueueMetadataKey(destination, deviceId),
getQueueIndexKey(destination, deviceId)),
Collections.emptyList()));
}
private static byte[] getMessageQueueKey(final String address, final long deviceId) {
return ("user_queue::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getMessageQueueMetadataKey(final String address, final long deviceId) {
return ("user_queue_metadata::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private byte[] getQueueIndexKey(final String address, final long deviceId) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(address + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8);
}
private byte[] getPersistInProgressKey(final String address, final long deviceId) {
return ("user_queue_persisting::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
}

View File

@@ -0,0 +1,41 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.push.PushSender;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public interface UserMessagesCache {
@VisibleForTesting
static OutgoingMessageEntity constructEntityFromEnvelope(long id, MessageProtos.Envelope envelope) {
return new OutgoingMessageEntity(id, true,
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getRelay(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}
long insert(UUID guid, String destination, long destinationDevice, MessageProtos.Envelope message);
Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, long id);
Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, String sender, long timestamp);
Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, UUID guid);
List<OutgoingMessageEntity> get(String destination, long destinationDevice, int limit);
void clear(String destination);
void clear(String destination, long deviceId);
}

View File

@@ -0,0 +1,37 @@
package org.whispersystems.textsecuregcm.util;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
public class RedisClusterUtil {
private static final String[] HASHES_BY_SLOT = new String[SlotHash.SLOT_COUNT];
static {
int slotsCovered = 0;
int i = 0;
while (slotsCovered < HASHES_BY_SLOT.length) {
final String hash = Integer.toString(i++, 36);
final int slot = SlotHash.getSlot(hash);
if (HASHES_BY_SLOT[slot] == null) {
HASHES_BY_SLOT[slot] = hash;
slotsCovered += 1;
}
}
}
/**
* Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key.
*
* @param key the key for which to find a matching hash tag
* @return a Redis hash tag that maps to the same Redis cluster slot as the given key
*
* @see <a href="https://redis.io/topics/cluster-spec#keys-hash-tags">Redis Cluster Specification - Keys hash tags</a>
*/
public static String getMinimalHashTag(final String key) {
return HASHES_BY_SLOT[SlotHash.getSlot(key)];
}
}