Don't cache authenticated accounts in memory

This commit is contained in:
Jon Chambers
2025-06-23 08:40:05 -05:00
committed by GitHub
parent 9dfe51eac4
commit c952baa672
86 changed files with 961 additions and 2264 deletions

View File

@@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Tags;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -20,7 +21,10 @@ import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
@@ -36,6 +40,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class);
private final AccountsManager accountsManager;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final MessageMetrics messageMetrics;
@@ -51,7 +56,9 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
public AuthenticatedConnectListener(ReceiptSender receiptSender,
public AuthenticatedConnectListener(
AccountsManager accountsManager,
ReceiptSender receiptSender,
MessagesManager messagesManager,
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
@@ -62,6 +69,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.accountsManager = accountsManager;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics;
@@ -82,7 +91,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
public void onWebSocketConnect(final WebSocketSessionContext context) {
final boolean authenticated = (context.getAuthenticated() != null);
final OpenWebSocketCounter openWebSocketCounter =
@@ -92,12 +101,24 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
if (authenticated) {
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
final Optional<Account> maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier());
final Optional<Device> maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.getDeviceId()));;
if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) {
log.warn("{}:{} not found when opening authenticated WebSocket", auth.getAccountIdentifier(), auth.getDeviceId());
context.getClient().close(1011, "Unexpected error initializing connection");
return;
}
final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager,
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
auth,
maybeAuthenticatedAccount.get(),
maybeAuthenticatedDevice.get(),
context.getClient(),
scheduledExecutorService,
messageDeliveryScheduler,
@@ -110,8 +131,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// receive push notifications for inbound messages. We should do this first because, at this point, the
// connection has already closed and attempts to actually deliver a message via the connection will not succeed.
// It's preferable to start sending push notifications as soon as possible.
webSocketConnectionEventManager.handleClientDisconnected(auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId());
webSocketConnectionEventManager.handleClientDisconnected(auth.getAccountIdentifier(), auth.getDeviceId());
// Finally, stop trying to deliver messages and send a push notification if the connection is aware of any
// undelivered messages.
@@ -127,7 +147,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// Finally, we register this client's presence, which suppresses push notifications. We do this last because
// receiving extra push notifications is generally preferable to missing out on a push notification.
webSocketConnectionEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection);
webSocketConnectionEventManager.handleClientConnected(auth.getAccountIdentifier(), auth.getDeviceId(), connection);
} catch (final Exception e) {
log.warn("Failed to initialize websocket", e);
context.getClient().close(1011, "Unexpected error initializing connection");

View File

@@ -9,41 +9,39 @@ import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentials
import com.google.common.net.HttpHeaders;
import javax.annotation.Nullable;
import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.Optional;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedDevice> {
private static final ReusableAuth<AuthenticatedDevice> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedDevice> principalSupplier;
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator,
final PrincipalSupplier<AuthenticatedDevice> principalSupplier) {
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
this.principalSupplier = principalSupplier;
}
@Override
public ReusableAuth<AuthenticatedDevice> authenticate(final UpgradeRequest request)
public Optional<AuthenticatedDevice> authenticate(final UpgradeRequest request)
throws InvalidCredentialsException {
@Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED;
return Optional.empty();
}
return basicCredentialsFromAuthHeader(authHeader)
.flatMap(accountAuthenticator::authenticate)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
final BasicCredentials credentials = basicCredentialsFromAuthHeader(authHeader)
.orElseThrow(InvalidCredentialsException::new);
final AuthenticatedDevice authenticatedDevice = accountAuthenticator.authenticate(credentials)
.orElseThrow(InvalidCredentialsException::new);
return Optional.of(authenticatedDevice);
}
}

View File

@@ -37,7 +37,6 @@ import org.eclipse.jetty.util.StaticException;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
@@ -52,6 +51,7 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@@ -123,7 +123,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final AuthenticatedDevice auth;
private final Account authenticatedAccount;
private final Device authenticatedDevice;
private final WebSocketClient client;
private final int sendFuturesTimeoutMillis;
@@ -156,7 +157,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth,
Account authenticatedAccount,
Device authenticatedDevice,
WebSocketClient client,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
@@ -169,7 +171,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
auth,
authenticatedAccount,
authenticatedDevice,
client,
DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS,
scheduledExecutorService,
@@ -184,7 +187,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth,
Account authenticatedAccount,
Device authenticatedDevice,
WebSocketClient client,
int sendFuturesTimeoutMillis,
ScheduledExecutorService scheduledExecutorService,
@@ -198,7 +202,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
this.messageMetrics = messageMetrics;
this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.auth = auth;
this.authenticatedAccount = authenticatedAccount;
this.authenticatedDevice = authenticatedDevice;
this.client = client;
this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis;
this.scheduledExecutorService = scheduledExecutorService;
@@ -209,7 +214,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
}
public void start() {
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), client.getUserAgent());
pushNotificationManager.handleMessagesRetrieved(authenticatedAccount, authenticatedDevice, client.getUserAgent());
queueDrainStartTime.set(System.currentTimeMillis());
processStoredMessages();
}
@@ -229,8 +234,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
client.close(1000, "OK");
if (storedMessageState.get() != StoredMessageState.EMPTY) {
pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(),
auth.getAuthenticatedDevice(),
pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount,
authenticatedDevice,
CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY);
}
}
@@ -242,7 +247,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
sendMessageCounter.increment();
sentMessageCounter.increment();
bytesSentCounter.increment(body.map(bytes -> bytes.length).orElse(0));
messageMetrics.measureAccountEnvelopeUuidMismatches(auth.getAccount(), message);
messageMetrics.measureAccountEnvelopeUuidMismatches(authenticatedAccount, message);
// X-Signal-Key: false must be sent until Android stops assuming it missing means true
return client.sendRequest("PUT", "/api/v1/message",
@@ -253,7 +258,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
} else {
messageMetrics.measureOutgoingMessageLatency(message.getServerTimestamp(),
"websocket",
auth.getAuthenticatedDevice().isPrimary(),
authenticatedDevice.isPrimary(),
message.getUrgent(),
message.getEphemeral(),
client.getUserAgent(),
@@ -263,12 +268,12 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final CompletableFuture<Void> result;
if (isSuccessResponse(response)) {
result = messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(),
result = messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice,
storedMessageInfo.guid(), storedMessageInfo.serverTimestamp())
.thenApply(ignored -> null);
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
recordMessageDeliveryDuration(message.getServerTimestamp(), auth.getAuthenticatedDevice());
recordMessageDeliveryDuration(message.getServerTimestamp(), authenticatedDevice);
sendDeliveryReceiptFor(message);
}
} else {
@@ -307,7 +312,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
try {
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
authenticatedDevice.getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
message.getClientTimestamp());
} catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceServiceId());
@@ -338,7 +343,6 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
// Cleared the queue! Send a queue empty message if we need to
consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
@@ -399,7 +403,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final CompletableFuture<Void> queueCleared = new CompletableFuture<>();
final Publisher<Envelope> messages =
messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly);
messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly);
final AtomicBoolean hasSentFirstMessage = new AtomicBoolean();
final AtomicBoolean hasErrored = new AtomicBoolean();
@@ -410,8 +414,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
.limitRate(MESSAGE_PUBLISHER_LIMIT_RATE)
.doOnNext(envelope -> {
if (hasSentFirstMessage.compareAndSet(false, true)) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI),
auth.getAuthenticatedDevice().getId(),
messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI),
authenticatedDevice.getId(),
UUID.fromString(envelope.getServerGuid()),
client.getUserAgent(),
"websocket");
@@ -471,7 +475,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getStory() && !client.shouldDeliverStories()) {
messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), messageGuid, envelope.getServerTimestamp());
messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, messageGuid, envelope.getServerTimestamp());
return CompletableFuture.completedFuture(null);
} else {