Scope disconnection request listeners to a single connection

This commit is contained in:
Jon Chambers
2025-07-23 12:13:04 -04:00
committed by Jon Chambers
parent 541c87e262
commit cf222e1105
13 changed files with 208 additions and 207 deletions

View File

@@ -5,20 +5,15 @@
package org.whispersystems.textsecuregcm.auth;
import java.util.Collection;
import java.util.UUID;
/**
* A disconnection request listener receives and handles requests to close authenticated client network connections.
* A disconnection request listener receives and handles a request to close an authenticated network connection for a
* specific client.
*/
public interface DisconnectionRequestListener {
/**
* Handles a request to close authenticated network connections for one or more authenticated devices. Requests are
* Handles a request to close an authenticated network connection for a specific authenticated device. Requests are
* dispatched on dedicated threads, and implementations may safely block.
*
* @param accountIdentifier the account identifier for which to close authenticated connections
* @param deviceIds the device IDs within the identified account for which to close authenticated connections
*/
void handleDisconnectionRequest(UUID accountIdentifier, Collection<Byte> deviceIds);
void handleDisconnectionRequest();
}

View File

@@ -5,21 +5,27 @@
package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.pubsub.RedisPubSubAdapter;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
@@ -37,11 +43,11 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil;
public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte[]> implements Managed {
private final FaultTolerantRedisClient pubSubClient;
private final GrpcClientConnectionManager grpcClientConnectionManager;
private final Executor listenerEventExecutor;
// We expect just a couple listeners to get added at startup time and not at all at steady-state. There are several
// reasonable ways to model this, but a copy-on-write list gives us good flexibility with minimal performance cost.
private final List<DisconnectionRequestListener> listeners = new CopyOnWriteArrayList<>();
private final Map<AccountIdentifierAndDeviceId, List<DisconnectionRequestListener>> listeners =
new ConcurrentHashMap<>();
@Nullable
private FaultTolerantPubSubConnection<byte[], byte[]> pubSubConnection;
@@ -56,10 +62,14 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte
private static final Logger logger = LoggerFactory.getLogger(DisconnectionRequestManager.class);
private record AccountIdentifierAndDeviceId(UUID accountIdentifier, byte deviceId) {}
public DisconnectionRequestManager(final FaultTolerantRedisClient pubSubClient,
final GrpcClientConnectionManager grpcClientConnectionManager,
final Executor listenerEventExecutor) {
this.pubSubClient = pubSubClient;
this.grpcClientConnectionManager = grpcClientConnectionManager;
this.listenerEventExecutor = listenerEventExecutor;
}
@@ -85,13 +95,41 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte
}
/**
* Adds a listener for disconnection requests. Listeners will receive all broadcast disconnection requests regardless
* of whether the device in connection is connected to this server.
* Adds a listener for disconnection requests for a specific authenticated device.
*
* @param accountIdentifier TODO
* @param deviceId TODO
* @param listener the listener to register
*/
public void addListener(final DisconnectionRequestListener listener) {
listeners.add(listener);
public void addListener(final UUID accountIdentifier, final byte deviceId, final DisconnectionRequestListener listener) {
listeners.compute(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), (_, existingListeners) -> {
final List<DisconnectionRequestListener> listeners =
existingListeners == null ? new ArrayList<>() : existingListeners;
listeners.add(listener);
return listeners;
});
}
/**
* Removes a listener for disconnection requests for a specific authenticated device.
*
* @param accountIdentifier TODO
* @param deviceId TODO
* @param listener the listener to remove
*/
public void removeListener(final UUID accountIdentifier, final byte deviceId, final DisconnectionRequestListener listener) {
listeners.computeIfPresent(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), (_, existingListeners) -> {
existingListeners.remove(listener);
return existingListeners.isEmpty() ? null : existingListeners;
});
}
@VisibleForTesting
List<DisconnectionRequestListener> getListeners(final UUID accountIdentifier, final byte deviceId) {
return listeners.getOrDefault(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), Collections.emptyList());
}
/**
@@ -154,12 +192,17 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte
return;
}
for (final DisconnectionRequestListener listener : listeners) {
try {
listenerEventExecutor.execute(() -> listener.handleDisconnectionRequest(accountIdentifier, deviceIds));
} catch (final Exception e) {
logger.warn("Listener failed to handle disconnection request", e);
}
}
deviceIds.forEach(deviceId -> {
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(accountIdentifier, deviceId));
listeners.getOrDefault(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), Collections.emptyList())
.forEach(listener -> listenerEventExecutor.execute(() -> {
try {
listener.handleDisconnectionRequest();
} catch (final Exception e) {
logger.warn("Listener failed to handle disconnection request", e);
}
}));
});
}
}