Shard push scheduling cache

This commit is contained in:
Jon Chambers
2021-01-19 15:50:12 -05:00
committed by GitHub
parent e600e9c583
commit 943a5d1036
7 changed files with 436 additions and 169 deletions

View File

@@ -107,6 +107,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private RedisConfiguration pushScheduler;
@NotNull
@Valid
@JsonProperty
private RedisClusterConfiguration pushSchedulerCluster;
@NotNull
@Valid
@JsonProperty
@@ -287,6 +292,10 @@ public class WhisperServerConfiguration extends Configuration {
return pushScheduler;
}
public RedisClusterConfiguration getPushSchedulerCluster() {
return pushSchedulerCluster;
}
public DatabaseConfiguration getMessageStoreConfiguration() {
return messageStore;
}

View File

@@ -281,10 +281,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ReplicatedJedisPool directoryClient = directoryClientFactory.getRedisClientPool();
ReplicatedJedisPool pushSchedulerClient = pushSchedulerClientFactory.getRedisClientPool();
ClientResources generalCacheClientResources = ClientResources.builder().build();
ClientResources messageCacheClientResources = ClientResources.builder().build();
ClientResources presenceClientResources = ClientResources.builder().build();
ClientResources metricsCacheClientResources = ClientResources.builder().build();
ClientResources generalCacheClientResources = ClientResources.builder().build();
ClientResources messageCacheClientResources = ClientResources.builder().build();
ClientResources presenceClientResources = ClientResources.builder().build();
ClientResources metricsCacheClientResources = ClientResources.builder().build();
ClientResources pushSchedulerCacheClientResources = ClientResources.builder().ioThreadPoolSize(4).build();
ConnectionEventLogger.logConnectionEvents(generalCacheClientResources);
ConnectionEventLogger.logConnectionEvents(messageCacheClientResources);
@@ -295,6 +296,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FaultTolerantRedisCluster messagesCluster = new FaultTolerantRedisCluster("message_insert_cluster", config.getMessageCacheConfiguration().getRedisClusterConfiguration(), messageCacheClientResources);
FaultTolerantRedisCluster clientPresenceCluster = new FaultTolerantRedisCluster("client_presence_cluster", config.getClientPresenceClusterConfiguration(), presenceClientResources);
FaultTolerantRedisCluster metricsCluster = new FaultTolerantRedisCluster("metrics_cluster", config.getMetricsClusterConfiguration(), metricsCacheClientResources);
FaultTolerantRedisCluster pushSchedulerCluster = new FaultTolerantRedisCluster("push_scheduler", config.getPushSchedulerCluster(), pushSchedulerCacheClientResources);
BlockingQueue<Runnable> keyspaceNotificationDispatchQueue = new ArrayBlockingQueue<>(10_000);
Metrics.gaugeCollectionSize(name(getClass(), "keyspaceNotificationDispatchQueueSize"), Collections.emptyList(), keyspaceNotificationDispatchQueue);
@@ -336,7 +338,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ExternalServiceCredentialGenerator backupCredentialsGenerator = new ExternalServiceCredentialGenerator(config.getSecureBackupServiceConfiguration().getUserAuthenticationTokenSharedSecret(), new byte[0], false);
ExternalServiceCredentialGenerator paymentsCredentialsGenerator = new ExternalServiceCredentialGenerator(config.getPaymentsServiceConfiguration().getUserAuthenticationTokenSharedSecret(), new byte[0], false);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushSchedulerClient, apnSender, accountsManager);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushSchedulerClient, pushSchedulerCluster, apnSender, accountsManager);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender);
MessageSender messageSender = new MessageSender(apnFallbackManager, clientPresenceManager, messagesManager, gcmSender, apnSender, pushLatencyManager);

View File

@@ -9,8 +9,14 @@ import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.RatioGauge;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.RedisException;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
@@ -19,89 +25,197 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.Util;
import redis.clients.jedis.exceptions.JedisException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.exceptions.JedisException;
public class ApnFallbackManager implements Managed, Runnable {
public class ApnFallbackManager implements Managed {
private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class);
private static final String PENDING_NOTIFICATIONS_KEY = "PENDING_APN";
private static final String SINGLETON_PENDING_NOTIFICATIONS_KEY = "PENDING_APN";
static final String NEXT_SLOT_TO_PERSIST_KEY = "pending_notification_next_slot";
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Meter delivered = metricRegistry.meter(name(ApnFallbackManager.class, "voip_delivered"));
private static final Meter sent = metricRegistry.meter(name(ApnFallbackManager.class, "voip_sent" ));
private static final Meter retry = metricRegistry.meter(name(ApnFallbackManager.class, "voip_retry"));
private static final Meter evicted = metricRegistry.meter(name(ApnFallbackManager.class, "voip_evicted"));
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Meter delivered = metricRegistry.meter(name(ApnFallbackManager.class, "voip_delivered"));
private static final Meter sent = metricRegistry.meter(name(ApnFallbackManager.class, "voip_sent" ));
private static final Meter retry = metricRegistry.meter(name(ApnFallbackManager.class, "voip_retry"));
private static final Meter evicted = metricRegistry.meter(name(ApnFallbackManager.class, "voip_evicted"));
private static final Meter singletonDestinations = metricRegistry.meter(name(ApnFallbackManager.class, "singleton_destinations"));
private static final Meter clusterDestinations = metricRegistry.meter(name(ApnFallbackManager.class, "cluster_destinations"));
static {
metricRegistry.register(name(ApnFallbackManager.class, "voip_ratio"), new VoipRatioGauge(delivered, sent));
}
private final APNSender apnSender;
private final AccountsManager accountsManager;
private final APNSender apnSender;
private final AccountsManager accountsManager;
private final FaultTolerantRedisCluster cluster;
private final ReplicatedJedisPool jedisPool;
private final InsertOperation insertOperation;
private final GetOperation getOperation;
private final RemoveOperation removeOperation;
private final LuaScript getSingletonScript;
private final LuaScript removeSingletonScript;
private final ClusterLuaScript getClusterScript;
private final ClusterLuaScript insertClusterScript;
private final ClusterLuaScript removeClusterScript;
private AtomicBoolean running = new AtomicBoolean(false);
private boolean finished;
private final Thread singletonWorkerThread;
private final Thread[] clusterWorkerThreads = new Thread[CLUSTER_WORKER_THREAD_COUNT];
private static final int CLUSTER_WORKER_THREAD_COUNT = 4;
private final AtomicBoolean running = new AtomicBoolean(false);
private class SingletonCacheWorker implements Runnable {
@Override
public void run() {
while (running.get()) {
try {
for (final String numberAndDevice : getPendingDestinationsFromSingletonCache(100)) {
singletonDestinations.mark();
Optional<Pair<String, Long>> separated = getSeparated(numberAndDevice);
if (!separated.isPresent()) {
removeFromSingleton(numberAndDevice);
continue;
}
Optional<Account> account = accountsManager.get(separated.get().first());
if (!account.isPresent()) {
removeFromSingleton(numberAndDevice);
continue;
}
Optional<Device> device = account.get().getDevice(separated.get().second());
if (!device.isPresent()) {
removeFromSingleton(numberAndDevice);
continue;
}
sendNotification(account.get(), device.get());
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
Util.sleep(1000);
}
}
}
class ClusterCacheWorker implements Runnable {
@Override
public void run() {
while (running.get()) {
try {
final long entriesProcessed = processNextSlot();
if (entriesProcessed == 0) {
Util.sleep(1000);
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
}
}
long processNextSlot() {
final int slot = getNextSlot();
List<String> pendingDestinations;
long entriesProcessed = 0;
do {
pendingDestinations = getPendingDestinationsFromClusterCache(slot, 100);
entriesProcessed += pendingDestinations.size();
for (final String uuidAndDevice : pendingDestinations) {
clusterDestinations.mark();
final Optional<Pair<String, Long>> separated = getSeparated(uuidAndDevice);
final Optional<Account> maybeAccount = separated.map(Pair::first)
.map(UUID::fromString)
.flatMap(accountsManager::get);
final Optional<Device> maybeDevice = separated.map(Pair::second)
.flatMap(deviceId -> maybeAccount.flatMap(account -> account.getDevice(deviceId)));
if (maybeAccount.isPresent() && maybeDevice.isPresent()) {
sendNotification(maybeAccount.get(), maybeDevice.get());
} else {
removeFromCluster(uuidAndDevice);
}
}
} while (!pendingDestinations.isEmpty());
return entriesProcessed;
}
}
public ApnFallbackManager(ReplicatedJedisPool jedisPool,
FaultTolerantRedisCluster cluster,
APNSender apnSender,
AccountsManager accountsManager)
throws IOException
{
this.apnSender = apnSender;
this.accountsManager = accountsManager;
this.jedisPool = jedisPool;
this.insertOperation = new InsertOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
}
this.cluster = cluster;
public void schedule(Account account, Device device) throws RedisException {
try {
sent.mark();
insertOperation.insert(account, device, System.currentTimeMillis() + (15 * 1000), (15 * 1000));
} catch (JedisException e) {
throw new RedisException(e);
this.getSingletonScript = LuaScript.fromResource(jedisPool, "lua/apn/get.lua");
this.removeSingletonScript = LuaScript.fromResource(jedisPool, "lua/apn/remove.lua");
this.getClusterScript = ClusterLuaScript.fromResource(cluster, "lua/apn/get.lua", ScriptOutputType.MULTI);
this.insertClusterScript = ClusterLuaScript.fromResource(cluster, "lua/apn/insert.lua", ScriptOutputType.VALUE);
this.removeClusterScript = ClusterLuaScript.fromResource(cluster, "lua/apn/remove.lua", ScriptOutputType.INTEGER);
this.singletonWorkerThread = new Thread(new SingletonCacheWorker(), "ApnFallbackManagerSingletonWorker");
for (int i = 0; i < this.clusterWorkerThreads.length; i++) {
this.clusterWorkerThreads[i] = new Thread(new ClusterCacheWorker(), "ApnFallbackManagerClusterWorker-" + i);
}
}
public boolean isScheduled(Account account, Device device) throws RedisException {
try {
String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId();
public void schedule(Account account, Device device) throws RedisException {
schedule(account, device, System.currentTimeMillis());
}
try (Jedis jedis = jedisPool.getReadResource()) {
return jedis.zscore(PENDING_NOTIFICATIONS_KEY, endpoint) != null;
}
} catch (JedisException e) {
@VisibleForTesting
void schedule(Account account, Device device, long timestamp) throws RedisException {
try {
sent.mark();
insert(account, device, timestamp + (15 * 1000), (15 * 1000));
} catch (io.lettuce.core.RedisException e) {
throw new RedisException(e);
}
}
public void cancel(Account account, Device device) throws RedisException {
try {
if (removeOperation.remove(account, device)) {
if (remove(account, device)) {
delivered.mark();
}
} catch (JedisException e) {
} catch (JedisException | io.lettuce.core.RedisException e) {
throw new RedisException(e);
}
}
@@ -109,77 +223,45 @@ public class ApnFallbackManager implements Managed, Runnable {
@Override
public synchronized void start() {
running.set(true);
new Thread(this).start();
singletonWorkerThread.start();
for (final Thread clusterWorkerThread : clusterWorkerThreads) {
clusterWorkerThread.start();
}
}
@Override
public synchronized void stop() {
public synchronized void stop() throws InterruptedException {
running.set(false);
while (!finished) Util.wait(this);
}
singletonWorkerThread.join();
@Override
public void run() {
while (running.get()) {
try {
List<byte[]> pendingNotifications = getOperation.getPending(100);
for (byte[] pendingNotification : pendingNotifications) {
String numberAndDevice = new String(pendingNotification);
Optional<Pair<String, Long>> separated = getSeparated(numberAndDevice);
if (!separated.isPresent()) {
removeOperation.remove(numberAndDevice);
continue;
}
Optional<Account> account = accountsManager.get(separated.get().first());
if (!account.isPresent()) {
removeOperation.remove(numberAndDevice);
continue;
}
Optional<Device> device = account.get().getDevice(separated.get().second());
if (!device.isPresent()) {
removeOperation.remove(numberAndDevice);
continue;
}
String apnId = device.get().getVoipApnId();
if (apnId == null) {
removeOperation.remove(account.get(), device.get());
continue;
}
long deviceLastSeen = device.get().getLastSeen();
if (deviceLastSeen < System.currentTimeMillis() - TimeUnit.DAYS.toMillis(90)) {
evicted.mark();
removeOperation.remove(account.get(), device.get());
continue;
}
apnSender.sendMessage(new ApnMessage(apnId, separated.get().first(), separated.get().second(), true, Optional.empty()));
retry.mark();
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
Util.sleep(1000);
}
synchronized (ApnFallbackManager.this) {
finished = true;
notifyAll();
for (final Thread clusterWorkerThread : clusterWorkerThreads) {
clusterWorkerThread.join();
}
}
private Optional<Pair<String, Long>> getSeparated(String encoded) {
private void sendNotification(final Account account, final Device device) {
String apnId = device.getVoipApnId();
if (apnId == null) {
remove(account, device);
return;
}
long deviceLastSeen = device.getLastSeen();
if (deviceLastSeen < System.currentTimeMillis() - TimeUnit.DAYS.toMillis(90)) {
evicted.mark();
remove(account, device);
return;
}
apnSender.sendMessage(new ApnMessage(apnId, account.getNumber(), device.getId(), true, Optional.empty()));
retry.mark();
}
@VisibleForTesting
static Optional<Pair<String, Long>> getSeparated(String encoded) {
try {
if (encoded == null) return Optional.empty();
@@ -197,66 +279,78 @@ public class ApnFallbackManager implements Managed, Runnable {
}
}
private static class RemoveOperation {
private final LuaScript luaScript;
RemoveOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/remove.lua");
}
boolean remove(Account account, Device device) {
String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId();
return remove(endpoint);
}
boolean remove(String endpoint) {
if (!PENDING_NOTIFICATIONS_KEY.equals(endpoint)) {
List<byte[]> keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes(), endpoint.getBytes());
List<byte[]> args = Collections.emptyList();
return ((long)luaScript.execute(keys, args)) > 0;
}
return false;
}
private boolean remove(Account account, Device device) {
final boolean removedFromSingleton = removeFromSingleton(getSingletonEndpointKey(account, device));
final boolean removedFromCluster = removeFromCluster(getClusterEndpointKey(account, device));
return removedFromSingleton || removedFromCluster;
}
private static class GetOperation {
private boolean removeFromSingleton(String endpoint) {
if (!SINGLETON_PENDING_NOTIFICATIONS_KEY.equals(endpoint)) {
List<byte[]> keys = Arrays.asList(SINGLETON_PENDING_NOTIFICATIONS_KEY.getBytes(), endpoint.getBytes());
List<byte[]> args = Collections.emptyList();
private final LuaScript luaScript;
GetOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/get.lua");
return ((long)removeSingletonScript.execute(keys, args)) > 0;
}
@SuppressWarnings("SameParameterValue")
List<byte[]> getPending(int limit) {
List<byte[]> keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes());
List<byte[]> args = Arrays.asList(String.valueOf(System.currentTimeMillis()).getBytes(), String.valueOf(limit).getBytes());
return (List<byte[]>) luaScript.execute(keys, args);
}
return false;
}
private static class InsertOperation {
private boolean removeFromCluster(final String endpoint) {
final long removed = (long)removeClusterScript.execute(List.of(getClusterPendingNotificationQueueKey(endpoint), endpoint),
Collections.emptyList());
private final LuaScript luaScript;
return removed > 0;
}
InsertOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/insert.lua");
}
@SuppressWarnings("unchecked")
private List<String> getPendingDestinationsFromSingletonCache(final int limit) {
List<byte[]> keys = List.of(SINGLETON_PENDING_NOTIFICATIONS_KEY.getBytes());
List<byte[]> args = List.of(String.valueOf(System.currentTimeMillis()).getBytes(), String.valueOf(limit).getBytes());
public void insert(Account account, Device device, long timestamp, long interval) {
String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId();
return ((List<byte[]>) getSingletonScript.execute(keys, args))
.stream()
.map(bytes -> new String(bytes, StandardCharsets.UTF_8))
.collect(Collectors.toList());
}
List<byte[]> keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes(), endpoint.getBytes());
List<byte[]> args = Arrays.asList(String.valueOf(timestamp).getBytes(), String.valueOf(interval).getBytes(),
account.getNumber().getBytes(), String.valueOf(device.getId()).getBytes());
@SuppressWarnings("unchecked")
@VisibleForTesting
List<String> getPendingDestinationsFromClusterCache(final int slot, final int limit) {
return (List<String>)getClusterScript.execute(List.of(getClusterPendingNotificationQueueKey(slot)),
List.of(String.valueOf(System.currentTimeMillis()), String.valueOf(limit)));
}
luaScript.execute(keys, args);
}
private void insert(final Account account, final Device device, final long timestamp, final long interval) {
final String endpoint = getClusterEndpointKey(account, device);
insertClusterScript.execute(List.of(getClusterPendingNotificationQueueKey(endpoint), endpoint),
List.of(String.valueOf(timestamp),
String.valueOf(interval),
account.getUuid().toString(),
String.valueOf(device.getId())));
}
private String getSingletonEndpointKey(final Account account, final Device device) {
return "apn_device::" + account.getNumber() + "::" + device.getId();
}
@VisibleForTesting
String getClusterEndpointKey(final Account account, final Device device) {
return "apn_device::{" + account.getUuid() + "::" + device.getId() + "}";
}
private String getClusterPendingNotificationQueueKey(final String endpoint) {
return getClusterPendingNotificationQueueKey(SlotHash.getSlot(endpoint));
}
private String getClusterPendingNotificationQueueKey(final int slot) {
return SINGLETON_PENDING_NOTIFICATIONS_KEY + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
private int getNextSlot() {
return (int)(cluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) % SlotHash.SLOT_COUNT);
}
private static class VoipRatioGauge extends RatioGauge {