Add a cluster-capable message persister

This commit is contained in:
Jon Chambers
2020-07-21 11:47:53 -04:00
parent f9f93c77e2
commit beac73b6c8
5 changed files with 500 additions and 27 deletions

View File

@@ -0,0 +1,181 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
public class RedisClusterMessagePersister implements Managed {
private final RedisClusterMessagesCache messagesCache;
private final Messages messagesDatabase;
private final PubSubManager pubSubManager;
private final PushSender pushSender;
private final AccountsManager accountsManager;
private final Duration persistDelay;
private volatile boolean running = false;
private Thread workerThread;
private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "getQueues"));
private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "persistQueue"));
private static final Timer NOTIFY_SUBSCRIBERS_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "notifySubscribers"));
private static final DistributionSummary QUEUE_COUNT_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueCount"));
private static final DistributionSummary QUEUE_SIZE_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueSize"));
static final int QUEUE_BATCH_LIMIT = 100;
static final int MESSAGE_BATCH_LIMIT = 100;
private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagePersister.class);
public RedisClusterMessagePersister(final RedisClusterMessagesCache messagesCache, final Messages messagesDatabase, final PubSubManager pubSubManager, final PushSender pushSender, final AccountsManager accountsManager, final Duration persistDelay) {
this.messagesCache = messagesCache;
this.messagesDatabase = messagesDatabase;
this.pubSubManager = pubSubManager;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.persistDelay = persistDelay;
}
@Override
public void start() {
running = true;
workerThread = new Thread(() -> {
while (running) {
persistNextQueues(Instant.now());
Util.sleep(100);
}
});
workerThread.start();
}
@Override
public void stop() throws Exception {
running = false;
if (workerThread != null) {
workerThread.join();
workerThread = null;
}
}
@VisibleForTesting
void persistNextQueues(final Instant currentTime) {
final int slot = messagesCache.getNextSlotToPersist();
List<String> queuesToPersist;
int queuesPersisted = 0;
do {
queuesToPersist = GET_QUEUES_TIMER.record(() -> messagesCache.getQueuesToPersist(slot, currentTime.minus(persistDelay), QUEUE_BATCH_LIMIT));
for (final String queue : queuesToPersist) {
persistQueue(queue);
notifyClients(RedisClusterMessagesCache.getAccountUuidFromQueueName(queue), RedisClusterMessagesCache.getDeviceIdFromQueueName(queue));
}
queuesPersisted += queuesToPersist.size();
} while (queuesToPersist.size() == QUEUE_BATCH_LIMIT);
QUEUE_COUNT_SUMMARY.record(queuesPersisted);
}
@VisibleForTesting
void persistQueue(final String queue) {
final UUID accountUuid = RedisClusterMessagesCache.getAccountUuidFromQueueName(queue);
final long deviceId = RedisClusterMessagesCache.getDeviceIdFromQueueName(queue);
final Optional<Account> maybeAccount = accountsManager.get(accountUuid);
final String accountNumber;
if (maybeAccount.isPresent()) {
accountNumber = maybeAccount.get().getNumber();
} else {
logger.error("No account record found for account {}", accountUuid);
return;
}
PERSIST_QUEUE_TIMER.record(() -> {
messagesCache.lockQueueForPersistence(queue);
try {
int messageCount = 0;
List<MessageProtos.Envelope> messages;
do {
messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT);
for (final MessageProtos.Envelope message : messages) {
final UUID uuid = UUID.fromString(message.getServerGuid());
messagesDatabase.store(uuid, message, accountNumber, deviceId);
messagesCache.remove(accountNumber, accountUuid, deviceId, uuid);
messageCount++;
}
} while (messages.size() == MESSAGE_BATCH_LIMIT);
QUEUE_SIZE_SUMMARY.record(messageCount);
} finally {
messagesCache.unlockQueueForPersistence(queue);
}
});
}
public void notifyClients(final UUID accountUuid, final long deviceId) {
NOTIFY_SUBSCRIBERS_TIMER.record(() -> {
final Optional<Account> maybeAccount = accountsManager.get(accountUuid);
final String address;
if (maybeAccount.isPresent()) {
address = maybeAccount.get().getNumber();
} else {
logger.error("No account record found for account {}", accountUuid);
return;
}
final boolean notified = pubSubManager.publish(new WebsocketAddress(address, deviceId),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
.build());
if (!notified) {
Optional<Account> account = accountsManager.get(address);
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(deviceId);
if (device.isPresent()) {
try {
pushSender.sendQueuedNotification(account.get(), device.get());
} catch (final NotPushRegisteredException e) {
logger.warn("After message persistence, no longer push registered!");
}
}
}
}
});
}
}

View File

@@ -1,7 +1,9 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -13,6 +15,7 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -23,12 +26,18 @@ import static com.codahale.metrics.MetricRegistry.name;
public class RedisClusterMessagesCache implements UserMessagesCache {
private final FaultTolerantRedisCluster redisCluster;
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeByIdScript;
private final ClusterLuaScript removeBySenderScript;
private final ClusterLuaScript removeByGuidScript;
private final ClusterLuaScript getItemsScript;
private final ClusterLuaScript removeQueueScript;
private final ClusterLuaScript getQueuesToPersistScript;
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert");
private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove");
@@ -44,12 +53,15 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
this.redisCluster = redisCluster;
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", ScriptOutputType.MULTI);
}
@Override
@@ -122,13 +134,13 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID guid) {
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() ->
removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
List.of(guid.toString().getBytes(StandardCharsets.UTF_8))));
List.of(messageGuid.toString().getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
@@ -142,11 +154,11 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
@Override
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(String destination, final UUID destinationUuid, long destinationDevice, int limit) {
public List<OutgoingMessageEntity> get(final String destination, final UUID destinationUuid, final long destinationDevice, final int limit) {
return Metrics.timer(GET_TIMER_NAME).record(() -> {
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getPersistInProgressKey(destinationUuid, destinationDevice)),
List.of(String.valueOf(limit).getBytes()));
List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8)));
final List<OutgoingMessageEntity> messageEntities;
@@ -172,6 +184,35 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
});
}
@SuppressWarnings("unchecked")
@VisibleForTesting
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final long destinationDevice, final int limit) {
return Metrics.timer(GET_TIMER_NAME).record(() -> {
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(accountUuid, destinationDevice),
getPersistInProgressKey(accountUuid, destinationDevice)),
List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8)));
final List<MessageProtos.Envelope> envelopes;
if (queueItems.size() % 2 == 0) {
envelopes = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size(); i += 2) {
try {
envelopes.add(MessageProtos.Envelope.parseFrom(queueItems.get(i)));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
} else {
logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
envelopes = Collections.emptyList();
}
return envelopes;
});
}
@Override
public void clear(final String destination, final UUID destinationUuid) {
// TODO Remove null check in a fully UUID-based world
@@ -191,7 +232,27 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
Collections.emptyList()));
}
private static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) {
int getNextSlotToPersist() {
return (int)(redisCluster.withWriteCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) % SlotHash.SLOT_COUNT);
}
List<String> getQueuesToPersist(final int slot, final Instant maxTime, final int limit) {
//noinspection unchecked
return (List<String>)getQueuesToPersistScript.execute(List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)),
List.of(String.valueOf(maxTime.toEpochMilli()),
String.valueOf(limit)));
}
void lockQueueForPersistence(final String queue) {
redisCluster.useBinaryWriteCluster(connection -> connection.sync().setex(getPersistInProgressKey(queue), 30, LOCK_VALUE));
}
void unlockQueueForPersistence(final String queue) {
redisCluster.useBinaryWriteCluster(connection -> connection.sync().del(getPersistInProgressKey(queue)));
}
@VisibleForTesting
static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) {
return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
@@ -199,11 +260,29 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(accountUuid.toString() + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8);
private static byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) {
return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId));
}
private byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) {
return ("user_queue_persisting::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
private static byte[] getQueueIndexKey(final int slot) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) {
return getPersistInProgressKey(accountUuid + "::" + deviceId);
}
private static byte[] getPersistInProgressKey(final String queueName) {
return ("user_queue_persisting::{" + queueName + "}").getBytes(StandardCharsets.UTF_8);
}
static UUID getAccountUuidFromQueueName(final String queueName) {
final int startOfHashTag = queueName.indexOf('{');
return UUID.fromString(queueName.substring(startOfHashTag + 1, queueName.indexOf("::", startOfHashTag)));
}
static long getDeviceIdFromQueueName(final String queueName) {
return Long.parseLong(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}')));
}
}

View File

@@ -1,8 +1,6 @@
package org.whispersystems.textsecuregcm.util;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
public class RedisClusterUtil {
@@ -23,15 +21,7 @@ public class RedisClusterUtil {
}
}
/**
* Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key.
*
* @param key the key for which to find a matching hash tag
* @return a Redis hash tag that maps to the same Redis cluster slot as the given key
*
* @see <a href="https://redis.io/topics/cluster-spec#keys-hash-tags">Redis Cluster Specification - Keys hash tags</a>
*/
public static String getMinimalHashTag(final String key) {
return HASHES_BY_SLOT[SlotHash.getSlot(key)];
public static String getMinimalHashTag(final int slot) {
return HASHES_BY_SLOT[slot];
}
}