Introduce and evaluate a client presence manager based on sharded pub/sub

This commit is contained in:
Jon Chambers
2024-11-05 15:51:29 -05:00
committed by GitHub
parent 60cdcf5f0c
commit 8c984cbf42
35 changed files with 1339 additions and 56 deletions

View File

@@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
/**
* A client event listener handles events related to a client's message-retrieval presence. Handler methods are run on
* dedicated threads and may safely perform blocking operations.
*/
public interface ClientEventListener {
/**
* Indicates that a new message is available in the connected client's message queue.
*/
void handleNewMessageAvailable();
/**
* Indicates that the client's presence has been displaced and the listener should close the client's underlying
* network connection.
*
* @param connectedElsewhere if {@code true}, indicates that the client's presence has been displaced by another
* connection from the same client
*/
void handleConnectionDisplaced(boolean connectedElsewhere);
}

View File

@@ -14,6 +14,7 @@ import io.lettuce.core.RedisFuture;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter;
@@ -277,7 +278,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
.subscribe(getKeyspaceNotificationChannel(presenceKey)));
}
private void resubscribeAll() {
private void resubscribeAll(final ClusterTopologyChangedEvent event) {
for (final String presenceKey : displacementListenersByPresenceKey.keySet()) {
subscribeForRemotePresenceChanges(presenceKey);
}

View File

@@ -8,9 +8,11 @@ import static com.codahale.metrics.MetricRegistry.name;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import java.util.Objects;
/**
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
@@ -28,6 +30,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
public class MessageSender {
private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;
private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager;
@@ -35,15 +38,18 @@ public class MessageSender {
private static final String CHANNEL_TAG_NAME = "channel";
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String CLIENT_ONLINE_TAG_NAME = "clientOnline";
private static final String PUB_SUB_CLIENT_ONLINE_TAG_NAME = "pubSubClientOnline";
private static final String URGENT_TAG_NAME = "urgent";
private static final String STORY_TAG_NAME = "story";
private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
public MessageSender(final ClientPresenceManager clientPresenceManager,
final PubSubClientEventManager pubSubClientEventManager,
final MessagesManager messagesManager,
final PushNotificationManager pushNotificationManager) {
this.clientPresenceManager = clientPresenceManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
}
@@ -88,13 +94,15 @@ public class MessageSender {
}
}
Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, channel,
EPHEMERAL_TAG_NAME, String.valueOf(online),
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment();
pubSubClientEventManager.handleNewMessageAvailable(account.getIdentifier(IdentityType.ACI), device.getId())
.whenComplete((present, throwable) -> Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, channel,
EPHEMERAL_TAG_NAME, String.valueOf(online),
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
PUB_SUB_CLIENT_ONLINE_TAG_NAME, String.valueOf(Objects.requireNonNullElse(present, false)),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment());
}
}

View File

@@ -0,0 +1,407 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
/**
* The pub/sub-based client presence manager uses the Redis 7 sharded pub/sub system to notify connected clients that
* new messages are available for retrieval and report to senders whether a client was present to receive a message when
* sent. This system makes a best effort to ensure that a given client has only a single open connection across the
* fleet of servers, but cannot guarantee at-most-one behavior.
*/
public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[], byte[]> implements Managed {
private final FaultTolerantRedisClusterClient clusterClient;
private final Executor listenerEventExecutor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
static final String EXPERIMENT_NAME = "pubSubPresenceManager";
@Nullable
private FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubConnection;
private final Map<AccountAndDeviceIdentifier, ConnectionIdAndListener> listenersByAccountAndDeviceIdentifier;
private static final byte[] NEW_MESSAGE_EVENT_BYTES = ClientEvent.newBuilder()
.setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance())
.build()
.toByteArray();
private static final byte[] DISCONNECT_REQUESTED_EVENT_BYTES = ClientEvent.newBuilder()
.setDisconnectRequested(DisconnectRequested.getDefaultInstance())
.build()
.toByteArray();
private static final Counter PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER =
Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "publishClientConnectionEventError"));
private static final Counter UNSUBSCRIBE_ERROR_COUNTER =
Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "unsubscribeError"));
private static final Counter MESSAGE_WITHOUT_LISTENER_COUNTER =
Metrics.counter(MetricsUtil.name(PubSubClientEventManager.class, "messageWithoutListener"));
private static final String LISTENER_GAUGE_NAME =
MetricsUtil.name(PubSubClientEventManager.class, "listeners");
private static final Logger logger = LoggerFactory.getLogger(PubSubClientEventManager.class);
private record AccountAndDeviceIdentifier(UUID accountIdentifier, byte deviceId) {
}
private record ConnectionIdAndListener(UUID connectionIdentifier, ClientEventListener listener) {
}
public PubSubClientEventManager(final FaultTolerantRedisClusterClient clusterClient,
final Executor listenerEventExecutor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.clusterClient = clusterClient;
this.listenerEventExecutor = listenerEventExecutor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.listenersByAccountAndDeviceIdentifier =
Metrics.gaugeMapSize(LISTENER_GAUGE_NAME, Tags.empty(), new ConcurrentHashMap<>());
}
@Override
public synchronized void start() {
this.pubSubConnection = clusterClient.createBinaryPubSubConnection();
this.pubSubConnection.usePubSubConnection(connection -> connection.addListener(this));
pubSubConnection.subscribeToClusterTopologyChangedEvents(this::resubscribe);
}
@Override
public synchronized void stop() {
if (pubSubConnection != null) {
pubSubConnection.usePubSubConnection(connection -> {
connection.removeListener(this);
connection.close();
});
}
pubSubConnection = null;
}
/**
* Marks the given device as "present" and registers a listener for new messages and conflicting connections. If the
* given device already has a presence registered with this presence manager instance, that presence is displaced
* immediately and the listener's {@link ClientEventListener#handleConnectionDisplaced(boolean)} method is called.
*
* @param accountIdentifier the account identifier for the newly-connected device
* @param deviceId the ID of the newly-connected device within the given account
* @param listener the listener to notify when new messages or conflicting connections arrive for the newly-conencted
* device
*
* @return a future that yields a connection identifier when the new device's presence has been registered; the future
* may fail if a pub/sub subscription could not be established, in which case callers should close the client's
* connection to the server
*/
public CompletionStage<UUID> handleClientConnected(final UUID accountIdentifier, final byte deviceId, final ClientEventListener listener) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}
if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(UUID.randomUUID());
}
final UUID connectionId = UUID.randomUUID();
final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId);
final AtomicReference<ClientEventListener> displacedListener = new AtomicReference<>();
final AtomicReference<CompletionStage<Void>> subscribeFuture = new AtomicReference<>();
// Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In
// particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says:
//
// > The entire method invocation is performed atomically. The supplied function is invoked exactly once per
// > invocation of this method. Some attempted update operations on this map by other threads may be blocked while
// > computation is in progress, so the computation should be short and simple.
//
// This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order
// as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
// operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
(key, existingIdAndListener) -> {
subscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().ssubscribe(clientPresenceKey)));
if (existingIdAndListener != null) {
displacedListener.set(existingIdAndListener.listener());
}
return new ConnectionIdAndListener(connectionId, listener);
});
if (displacedListener.get() != null) {
listenerEventExecutor.execute(() -> displacedListener.get().handleConnectionDisplaced(true));
}
return subscribeFuture.get()
.thenCompose(ignored -> clusterClient.withBinaryCluster(connection -> connection.async()
.spublish(clientPresenceKey, buildClientConnectedMessage(connectionId))))
.handle((ignored, throwable) -> {
if (throwable != null) {
PUBLISH_CLIENT_CONNECTION_EVENT_ERROR_COUNTER.increment();
}
return connectionId;
});
}
/**
* Removes the "presence" for the given device. The presence is removed if and only if the given connection ID matches
* the connection ID for the currently-registered presence. Callers should call this method when they have closed or
* intend to close the client's underlying network connection.
*
* @param accountIdentifier the identifier of the account for the disconnected device
* @param deviceId the ID of the disconnected device within the given account
* @param connectionId the ID of the connection that has been closed (or will be closed)
*
* @return a future that completes when the presence has been removed
*/
public CompletionStage<Void> handleClientDisconnected(final UUID accountIdentifier, final byte deviceId, final UUID connectionId) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}
if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(null);
}
final AtomicReference<CompletionStage<Void>> unsubscribeFuture = new AtomicReference<>();
// Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In
// particular, the behavioral contract for `ConcurrentHashMap#compute(...)` says:
//
// > The entire method invocation is performed atomically. The supplied function is invoked exactly once per
// > invocation of this method. Some attempted update operations on this map by other threads may be blocked while
// > computation is in progress, so the computation should be short and simple.
//
// This provides a mechanism to make sure that we enqueue subscription/unsubscription operations in the same order
// as adding/removing listeners from the map and helps us avoid races and conflicts. Note that the enqueued
// operation is asynchronous; we're not blocking on it in the scope of the `compute` operation.
listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
(ignored, existingIdAndListener) -> {
final ConnectionIdAndListener remainingIdAndListener;
if (existingIdAndListener == null) {
remainingIdAndListener = null;
} else if (existingIdAndListener.connectionIdentifier().equals(connectionId)) {
remainingIdAndListener = null;
} else {
remainingIdAndListener = existingIdAndListener;
}
if (remainingIdAndListener == null) {
// Only unsubscribe if there's no listener remaining
unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().sunsubscribe(getClientPresenceKey(accountIdentifier, deviceId)))
.thenRun(Util.NOOP));
} else {
unsubscribeFuture.set(CompletableFuture.completedFuture(null));
}
return remainingIdAndListener;
});
return unsubscribeFuture.get()
.whenComplete((ignored, throwable) -> {
if (throwable != null) {
UNSUBSCRIBE_ERROR_COUNTER.increment();
}
});
}
/**
* Publishes an event notifying a specific device that a new message is available for retrieval. This method indicates
* whether the target device is "present" (i.e. has an active listener). Callers may choose to take follow-up action
* (like sending a push notification) if the target device is not present.
*
* @param accountIdentifier the account identifier of the receiving device
* @param deviceId the ID of the receiving device within the target account
*
* @return a future that yields {@code true} if the target device had an active listener or {@code false} otherwise
*/
public CompletionStage<Boolean> handleNewMessageAvailable(final UUID accountIdentifier, final byte deviceId) {
if (pubSubConnection == null) {
throw new IllegalStateException("Presence manager not started");
}
if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(false);
}
return pubSubConnection.withPubSubConnection(connection ->
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES))
.thenApply(listeners -> listeners > 0);
}
/**
* Tests whether a client with the given account/device is connected to this presence manager instance.
*
* @param accountUuid the account identifier for the client to check
* @param deviceId the ID of the device within the given account
*
* @return {@code true} if a client with the given account/device is connected to this presence manager instance or
* {@code false} if the client is not connected at all or is connected to a different presence manager instance
*/
public boolean isLocallyPresent(final UUID accountUuid, final byte deviceId) {
return listenersByAccountAndDeviceIdentifier.containsKey(new AccountAndDeviceIdentifier(accountUuid, deviceId));
}
/**
* Broadcasts a request that all devices associated with the identified account and connected to any client presence
* instance close their network connections.
*
* @param accountIdentifier the account identifier for which to request disconnection
*
* @return a future that completes when the request has been sent
*/
public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier) {
return requestDisconnection(accountIdentifier, Device.ALL_POSSIBLE_DEVICE_IDS);
}
/**
* Broadcasts a request that the specified devices associated with the identified account and connected to any client
* presence instance close their network connections.
*
* @param accountIdentifier the account identifier for which to request disconnection
* @param deviceIds the IDs of the devices for which to request disconnection
*
* @return a future that completes when the request has been sent
*/
public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
return CompletableFuture.allOf(deviceIds.stream()
.map(deviceId -> {
final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId);
return clusterClient.withBinaryCluster(connection -> connection.async()
.spublish(clientPresenceKey, DISCONNECT_REQUESTED_EVENT_BYTES))
.toCompletableFuture();
})
.toArray(CompletableFuture[]::new));
}
@VisibleForTesting
void resubscribe(final ClusterTopologyChangedEvent clusterTopologyChangedEvent) {
final boolean[] changedSlots = RedisClusterUtil.getChangedSlots(clusterTopologyChangedEvent);
final Map<Integer, List<byte[]>> clientPresenceKeysBySlot = new HashMap<>();
// Organize subscriptions by slot so we can issue a smaller number of larger resubscription commands
listenersByAccountAndDeviceIdentifier.keySet()
.stream()
.map(accountAndDeviceIdentifier -> getClientPresenceKey(accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId()))
.forEach(clientPresenceKey -> {
final int slot = SlotHash.getSlot(clientPresenceKey);
if (changedSlots[slot]) {
clientPresenceKeysBySlot.computeIfAbsent(slot, ignored -> new ArrayList<>()).add(clientPresenceKey);
}
});
// Issue one resubscription command per affected slot
clientPresenceKeysBySlot.forEach((slot, clientPresenceKeys) -> {
if (pubSubConnection != null) {
final byte[][] clientPresenceKeyArray = clientPresenceKeys.toArray(byte[][]::new);
pubSubConnection.usePubSubConnection(connection -> connection.sync().ssubscribe(clientPresenceKeyArray));
}
});
}
@Override
public void smessage(final RedisClusterNode node, final byte[] shardChannel, final byte[] message) {
final ClientEvent clientEvent;
try {
clientEvent = ClientEvent.parseFrom(message);
} catch (final InvalidProtocolBufferException e) {
logger.error("Failed to parse pub/sub event protobuf", e);
return;
}
final AccountAndDeviceIdentifier accountAndDeviceIdentifier = parseClientPresenceKey(shardChannel);
@Nullable final ConnectionIdAndListener connectionIdAndListener =
listenersByAccountAndDeviceIdentifier.get(accountAndDeviceIdentifier);
if (connectionIdAndListener != null) {
switch (clientEvent.getEventCase()) {
case NEW_MESSAGE_AVAILABLE -> connectionIdAndListener.listener().handleNewMessageAvailable();
case CLIENT_CONNECTED -> {
final UUID connectionId = UUIDUtil.fromByteString(clientEvent.getClientConnected().getConnectionId());
if (!connectionIdAndListener.connectionIdentifier().equals(connectionId)) {
listenerEventExecutor.execute(() ->
connectionIdAndListener.listener().handleConnectionDisplaced(true));
}
}
case DISCONNECT_REQUESTED -> listenerEventExecutor.execute(() ->
connectionIdAndListener.listener().handleConnectionDisplaced(false));
default -> logger.warn("Unexpected client event type: {}", clientEvent.getClass());
}
} else {
MESSAGE_WITHOUT_LISTENER_COUNTER.increment();
}
}
private static byte[] buildClientConnectedMessage(final UUID connectionId) {
return ClientEvent.newBuilder()
.setClientConnected(ClientConnectedEvent.newBuilder()
.setConnectionId(UUIDUtil.toByteString(connectionId))
.build())
.build()
.toByteArray();
}
@VisibleForTesting
static byte[] getClientPresenceKey(final UUID accountIdentifier, final byte deviceId) {
return ("client_presence::{" + accountIdentifier + ":" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static AccountAndDeviceIdentifier parseClientPresenceKey(final byte[] clientPresenceKeyBytes) {
final String clientPresenceKey = new String(clientPresenceKeyBytes, StandardCharsets.UTF_8);
final int uuidStart = "client_presence::{".length();
final UUID accountIdentifier = UUID.fromString(clientPresenceKey.substring(uuidStart, uuidStart + 36));
final byte deviceId = Byte.parseByte(clientPresenceKey.substring(uuidStart + 37, clientPresenceKey.length() - 1));
return new AccountAndDeviceIdentifier(accountIdentifier, deviceId);
}
}