Revert "Use MessageStream in WebSocketConnection"

This reverts commit 470e17963a.
This commit is contained in:
Jon Chambers
2025-08-13 15:46:53 -04:00
committed by Jon Chambers
parent 0f2a4d02e0
commit a94ce72894
6 changed files with 1022 additions and 763 deletions

View File

@@ -994,9 +994,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, disconnectionRequestManager,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager
));
pushNotificationScheduler, redisMessageAvailabilityManager, disconnectionRequestManager,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);

View File

@@ -7,11 +7,10 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Tags;
import java.time.Duration;
import java.util.Optional;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@@ -24,12 +23,12 @@ import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import reactor.core.scheduler.Scheduler;
@@ -46,18 +45,21 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class);
private final AccountsManager accountsManager;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final MessageMetrics messageMetrics;
private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler;
private final RedisMessageAvailabilityManager redisMessageAvailabilityManager;
private final DisconnectionRequestManager disconnectionRequestManager;
private final WebSocketConnectionBuilder webSocketConnectionBuilder;
private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
@VisibleForTesting
@FunctionalInterface
interface WebSocketConnectionBuilder {
WebSocketConnection buildWebSocketConnection(Account account, Device device, WebSocketClient client);
}
public AuthenticatedConnectListener(
final AccountsManager accountsManager,
final ReceiptSender receiptSender,
@@ -65,63 +67,47 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final MessageMetrics messageMetrics,
final PushNotificationManager pushNotificationManager,
final PushNotificationScheduler pushNotificationScheduler,
final RedisMessageAvailabilityManager redisMessageAvailabilityManager,
final DisconnectionRequestManager disconnectionRequestManager,
final Scheduler messageDeliveryScheduler,
final ClientReleaseManager clientReleaseManager,
final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this(accountsManager,
disconnectionRequestManager,
(account, device, client) -> new WebSocketConnection(receiptSender,
messagesManager,
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
account,
device,
client,
messageDeliveryScheduler,
clientReleaseManager,
messageDeliveryLoopMonitor,
experimentEnrollmentManager),
authenticated -> new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME,
NEW_CONNECTION_COUNTER_NAME,
CONNECTED_DURATION_TIMER_NAME,
Duration.ofHours(3),
Tags.of(AUTHENTICATED_TAG_NAME, String.valueOf(authenticated)))
);
}
@VisibleForTesting AuthenticatedConnectListener(
final AccountsManager accountsManager,
final DisconnectionRequestManager disconnectionRequestManager,
final WebSocketConnectionBuilder webSocketConnectionBuilder,
final Function<Boolean, OpenWebSocketCounter> openWebSocketCounterBuilder) {
this.accountsManager = accountsManager;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics;
this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.redisMessageAvailabilityManager = redisMessageAvailabilityManager;
this.disconnectionRequestManager = disconnectionRequestManager;
this.webSocketConnectionBuilder = webSocketConnectionBuilder;
this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
openAuthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(true);
openUnauthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(false);
openAuthenticatedWebSocketCounter =
new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, NEW_CONNECTION_COUNTER_NAME, CONNECTED_DURATION_TIMER_NAME, Duration.ofHours(3), Tags.of(AUTHENTICATED_TAG_NAME, "true"));
openUnauthenticatedWebSocketCounter =
new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, NEW_CONNECTION_COUNTER_NAME, CONNECTED_DURATION_TIMER_NAME, Duration.ofHours(3), Tags.of(AUTHENTICATED_TAG_NAME, "false"));
}
@Override
public void onWebSocketConnect(final WebSocketSessionContext context) {
final boolean authenticated = (context.getAuthenticated() != null);
final OpenWebSocketCounter openWebSocketCounter =
authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter;
(authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter).countOpenWebSocket(context);
openWebSocketCounter.countOpenWebSocket(context);
if (authenticated) {
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
final Optional<Account> maybeAuthenticatedAccount =
accountsManager.getByAccountIdentifier(auth.accountIdentifier());
final Optional<Device> maybeAuthenticatedDevice =
maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId()));
final Optional<Account> maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.accountIdentifier());
final Optional<Device> maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId()));
if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) {
log.warn("{}:{} not found when opening authenticated WebSocket", auth.accountIdentifier(), auth.deviceId());
@@ -130,10 +116,18 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
return;
}
final WebSocketConnection connection =
webSocketConnectionBuilder.buildWebSocketConnection(maybeAuthenticatedAccount.get(),
maybeAuthenticatedDevice.get(),
context.getClient());
final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager,
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
maybeAuthenticatedAccount.get(),
maybeAuthenticatedDevice.get(),
context.getClient(),
messageDeliveryScheduler,
clientReleaseManager,
messageDeliveryLoopMonitor,
experimentEnrollmentManager);
disconnectionRequestManager.addListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI),
maybeAuthenticatedDevice.get().getId(),
@@ -144,11 +138,27 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
maybeAuthenticatedDevice.get().getId(),
connection);
// We begin the shutdown process by removing this client's "presence," which means it will again begin to
// receive push notifications for inbound messages. We should do this first because, at this point, the
// connection has already closed and attempts to actually deliver a message via the connection will not succeed.
// It's preferable to start sending push notifications as soon as possible.
redisMessageAvailabilityManager.handleClientDisconnected(auth.accountIdentifier(), auth.deviceId());
// Finally, stop trying to deliver messages and send a push notification if the connection is aware of any
// undelivered messages.
connection.stop();
});
try {
// Once we "start" the websocket connection, we'll cancel any scheduled "you may have new messages" push
// notifications and begin delivering any stored messages for the connected device. We have not yet declared the
// client as "present" yet. If a message arrives at this point, we will update the message availability state
// correctly, but we may also send a spurious push notification.
connection.start();
// Finally, we register this client's presence, which suppresses push notifications. We do this last because
// receiving extra push notifications is generally preferable to missing out on a push notification.
redisMessageAvailabilityManager.handleClientConnected(auth.accountIdentifier(), auth.deviceId(), connection);
} catch (final Exception e) {
log.warn("Failed to initialize websocket", e);
context.getClient().close(1011, "Unexpected error initializing connection");

View File

@@ -14,17 +14,22 @@ import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.util.StaticException;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener;
@@ -37,28 +42,26 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageAvailabilityListener;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.ConflictingMessageConsumerException;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageStream;
import org.whispersystems.textsecuregcm.storage.MessageStreamEntry;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import reactor.adapter.JdkFlowAdapter;
import reactor.core.Disposable;
import reactor.core.observability.micrometer.Micrometer;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.util.retry.Retry;
public class WebSocketConnection implements DisconnectionRequestListener {
public class WebSocketConnection implements MessageAvailabilityListener, DisconnectionRequestListener {
private static final Counter sendMessageCounter = Metrics.counter(name(WebSocketConnection.class, "sendMessage"));
private static final Counter bytesSentCounter = Metrics.counter(name(WebSocketConnection.class, "bytesSent"));
@@ -75,15 +78,17 @@ public class WebSocketConnection implements DisconnectionRequestListener {
"sendMessages");
private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class,
"sendMessageError");
private static final String MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class, "messagesAvailable");
private static final String MESSAGES_PERSISTED_COUNTER_NAME = name(WebSocketConnection.class, "messagesPersisted");
private static final String SEND_MESSAGE_DURATION_TIMER_NAME = name(WebSocketConnection.class, "sendMessageDuration");
private static final String PRESENCE_MANAGER_TAG = "presenceManager";
private static final String STATUS_CODE_TAG = "status";
private static final String STATUS_MESSAGE_TAG = "message";
private static final String ERROR_TYPE_TAG = "errorType";
private static final String EXCEPTION_TYPE_TAG = "exceptionType";
private static final String CONNECTED_ELSEWHERE_TAG = "connectedElsewhere";
private static final Duration SLOW_DRAIN_THRESHOLD = Duration.ofSeconds(10);
private static final long SLOW_DRAIN_THRESHOLD = 10_000;
@VisibleForTesting
static final int MESSAGE_PUBLISHER_LIMIT_RATE = 100;
@@ -103,21 +108,28 @@ public class WebSocketConnection implements DisconnectionRequestListener {
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final Retry retrySpec;
private final Account authenticatedAccount;
private final Device authenticatedDevice;
private final MessageStream messageStream;
private final WebSocketClient client;
private final Tags platformTag;
private final Semaphore processStoredMessagesSemaphore = new Semaphore(1);
private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(
StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
private final LongAdder sentMessageCounter = new LongAdder();
private final AtomicLong queueDrainStartNanoTime = new AtomicLong();
private final AtomicReference<Disposable> messageSubscription = new AtomicReference<>();
private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager;
private enum StoredMessageState {
EMPTY,
CACHED_NEW_MESSAGES_AVAILABLE,
PERSISTED_NEW_MESSAGES_AVAILABLE
}
public WebSocketConnection(final ReceiptSender receiptSender,
final MessagesManager messagesManager,
final MessageMetrics messageMetrics,
@@ -131,38 +143,6 @@ public class WebSocketConnection implements DisconnectionRequestListener {
final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this(receiptSender,
messagesManager,
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
authenticatedAccount,
authenticatedDevice,
client,
messageDeliveryScheduler,
clientReleaseManager,
messageDeliveryLoopMonitor,
experimentEnrollmentManager,
Retry.backoff(4, Duration.ofSeconds(1))
.maxBackoff(Duration.ofSeconds(2))
.filter(throwable -> !isConnectionClosedException(throwable)));
}
@VisibleForTesting
WebSocketConnection(final ReceiptSender receiptSender,
final MessagesManager messagesManager,
final MessageMetrics messageMetrics,
final PushNotificationManager pushNotificationManager,
final PushNotificationScheduler pushNotificationScheduler,
final Account authenticatedAccount,
final Device authenticatedDevice,
final WebSocketClient client,
final Scheduler messageDeliveryScheduler,
final ClientReleaseManager clientReleaseManager,
final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager,
final Retry retrySpec) {
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics;
@@ -175,79 +155,12 @@ public class WebSocketConnection implements DisconnectionRequestListener {
this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.retrySpec = retrySpec;
this.messageStream =
messagesManager.getMessages(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice);
this.platformTag = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
}
public void start() {
pushNotificationManager.handleMessagesRetrieved(authenticatedAccount, authenticatedDevice, client.getUserAgent());
final long queueDrainStartNanos = System.nanoTime();
final AtomicBoolean hasSentFirstMessage = new AtomicBoolean();
messageSubscription.set(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages())
.name(SEND_MESSAGES_FLUX_NAME)
.tap(Micrometer.metrics(Metrics.globalRegistry))
.limitRate(MESSAGE_PUBLISHER_LIMIT_RATE)
// We want to handle conflicting connections as soon as possible, and so do this before we start processing
// messages in the `flatMapSequential` stage below. If we didn't do this first, then we'd wait for clients to
// process messages before sending the "connected elsewhere" signal, and while that's ultimately not harmful,
// it's also not ideal.
.doOnError(ConflictingMessageConsumerException.class, _ -> {
Metrics.counter(DISPLACEMENT_COUNTER_NAME, platformTag.and(CONNECTED_ELSEWHERE_TAG, "true")).increment();
client.close(4409, "Connected elsewhere");
})
.doOnNext(entry -> {
if (entry instanceof MessageStreamEntry.Envelope(final Envelope message)) {
if (hasSentFirstMessage.compareAndSet(false, true)) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI),
authenticatedDevice.getId(),
UUID.fromString(message.getServerGuid()),
client.getUserAgent(),
"websocket");
}
}
})
.flatMapSequential(entry -> switch (entry) {
case MessageStreamEntry.Envelope envelope -> Mono.fromFuture(() -> sendMessage(envelope.message()))
.retryWhen(retrySpec)
.thenReturn(entry);
case MessageStreamEntry.QueueEmpty _ -> Mono.just(entry);
}, MESSAGE_SENDER_MAX_CONCURRENCY)
// `ConflictingMessageConsumerException` is handled before processing messages
.doOnError(throwable -> !(throwable instanceof ConflictingMessageConsumerException), throwable -> {
measureSendMessageErrors(throwable);
if (!client.isOpen()) {
logger.debug("Client disconnected before queue cleared");
return;
}
client.close(1011, "Failed to retrieve messages");
})
// Make sure we process message acknowledgements before sending the "queue clear" signal
.doOnNext(entry -> {
if (entry instanceof MessageStreamEntry.QueueEmpty) {
final Duration drainDuration = Duration.ofNanos(System.nanoTime() - queueDrainStartNanos);
Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, platformTag).record(sentMessageCounter.sum());
Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, platformTag).record(drainDuration);
if (drainDuration.compareTo(SLOW_DRAIN_THRESHOLD) > 0) {
Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, platformTag).increment();
}
client.sendRequest("PUT", "/api/v1/queue/empty",
Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty());
}
})
.subscribeOn(messageDeliveryScheduler)
.subscribe());
queueDrainStartNanoTime.set(System.nanoTime());
processStoredMessages();
}
public void stop() {
@@ -258,22 +171,16 @@ public class WebSocketConnection implements DisconnectionRequestListener {
client.close(1000, "OK");
messagesManager.mayHaveMessages(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice)
.thenAccept(mayHaveMessages -> {
if (mayHaveMessages) {
pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount,
authenticatedDevice,
CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY);
}
});
if (storedMessageState.get() != StoredMessageState.EMPTY) {
pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount,
authenticatedDevice,
CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY);
}
}
private CompletableFuture<Void> sendMessage(final Envelope message) {
if (message.getStory() && !client.shouldDeliverStories()) {
return messageStream.acknowledgeMessage(message);
}
final Optional<byte[]> body = Optional.of(serializeMessage(message));
private CompletableFuture<Void> sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) {
// clear ephemeral field from the envelope
final Optional<byte[]> body = Optional.ofNullable(message.toBuilder().clearEphemeral().build().toByteArray());
sendMessageCounter.increment();
sentMessageCounter.increment();
@@ -301,39 +208,40 @@ public class WebSocketConnection implements DisconnectionRequestListener {
final CompletableFuture<Void> result;
if (isSuccessResponse(response)) {
result = messageStream.acknowledgeMessage(message);
result = messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice,
storedMessageInfo.guid(), storedMessageInfo.serverTimestamp())
.thenApply(ignored -> null);
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
sendDeliveryReceiptFor(message);
}
} else {
Tags tags = platformTag.and(STATUS_CODE_TAG, String.valueOf(response.getStatus()));
final List<Tag> tags = new ArrayList<>(
List.of(
Tag.of(STATUS_CODE_TAG, String.valueOf(response.getStatus())),
UserAgentTagUtil.getPlatformTag(client.getUserAgent())
));
// TODO Remove this once we've identified the cause of message rejections from desktop clients
if (StringUtils.isNotBlank(response.getMessage())) {
tags = tags.and(Tag.of(STATUS_MESSAGE_TAG, response.getMessage()));
// TODO Remove this once we've identified the cause of message rejections from desktop clients
if (StringUtils.isNotBlank(response.getMessage())) {
tags.add(Tag.of(STATUS_MESSAGE_TAG, response.getMessage()));
}
Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment();
result = CompletableFuture.completedFuture(null);
}
Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment();
result = CompletableFuture.completedFuture(null);
}
return result;
})
.thenRun(() -> sample.stop(Timer.builder(SEND_MESSAGE_DURATION_TIMER_NAME)
.publishPercentileHistogram(true)
.minimumExpectedValue(Duration.ofMillis(100))
.maximumExpectedValue(Duration.ofDays(1))
.tags(platformTag)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())))
.register(Metrics.globalRegistry)));
}
@VisibleForTesting
static byte[] serializeMessage(final Envelope message) {
return message.toBuilder().clearEphemeral().build().toByteArray();
}
private void sendDeliveryReceiptFor(Envelope message) {
if (!message.hasSourceServiceId()) {
return;
@@ -343,17 +251,110 @@ public class WebSocketConnection implements DisconnectionRequestListener {
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
authenticatedDevice.getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
message.getClientTimestamp());
} catch (final IllegalArgumentException e) {
} catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceServiceId());
} catch (final Exception e) {
} catch (Exception e) {
logger.warn("Failed to send receipt", e);
}
}
private static boolean isSuccessResponse(final WebSocketResponseMessage response) {
private boolean isSuccessResponse(WebSocketResponseMessage response) {
return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
}
@VisibleForTesting
void processStoredMessages() {
if (processStoredMessagesSemaphore.tryAcquire()) {
final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY);
final boolean cachedMessagesOnly = state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE;
sendMessages(cachedMessagesOnly)
// Update our state with the outcome, send the empty queue message if we need to, and release the semaphore
.whenComplete((ignored, cause) -> {
try {
if (cause != null) {
// We failed, if the state is currently EMPTY, set it to what it was before we tried
storedMessageState.compareAndSet(StoredMessageState.EMPTY, state);
return;
}
// Cleared the queue! Send a queue empty message if we need to
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
final long drainDurationNanos = System.nanoTime() - queueDrainStartNanoTime.get();
Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum());
Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDurationNanos, TimeUnit.NANOSECONDS);
if (drainDurationNanos > SLOW_DRAIN_THRESHOLD) {
Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment();
}
client.sendRequest("PUT", "/api/v1/queue/empty",
Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty());
}
} finally {
processStoredMessagesSemaphore.release();
}
})
// Potentially kick off more work, must happen after we release the semaphore
.whenComplete((ignored, cause) -> {
if (cause != null) {
if (!client.isOpen()) {
logger.debug("Client disconnected before queue cleared");
return;
}
client.close(1011, "Failed to retrieve messages");
return;
}
// Success, but check if more messages came in while we were processing
if (storedMessageState.get() != StoredMessageState.EMPTY) {
processStoredMessages();
}
});
}
}
private CompletableFuture<Void> sendMessages(final boolean cachedMessagesOnly) {
final CompletableFuture<Void> queueCleared = new CompletableFuture<>();
final Publisher<Envelope> messages =
messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly);
final AtomicBoolean hasSentFirstMessage = new AtomicBoolean();
final Disposable subscription = Flux.from(messages)
.name(SEND_MESSAGES_FLUX_NAME)
.tap(Micrometer.metrics(Metrics.globalRegistry))
.limitRate(MESSAGE_PUBLISHER_LIMIT_RATE)
.doOnNext(envelope -> {
if (hasSentFirstMessage.compareAndSet(false, true)) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI),
authenticatedDevice.getId(),
UUID.fromString(envelope.getServerGuid()),
client.getUserAgent(),
"websocket");
}
})
.flatMapSequential(envelope -> Mono.fromFuture(() -> sendMessage(envelope))
.retryWhen(Retry.backoff(4, Duration.ofSeconds(1)).filter(throwable -> !isConnectionClosedException(throwable))),
MESSAGE_SENDER_MAX_CONCURRENCY)
.doOnError(this::measureSendMessageErrors)
.subscribeOn(messageDeliveryScheduler)
.subscribe(
// no additional consumer of values - it is Flux<Void> by now
null,
// this first error will terminate the stream, but we may get multiple errors from in-flight messages
queueCleared::completeExceptionally,
// completion
() -> queueCleared.complete(null)
);
messageSubscription.set(subscription);
return queueCleared;
}
private void measureSendMessageErrors(final Throwable e) {
final String errorType;
@@ -366,22 +367,76 @@ public class WebSocketConnection implements DisconnectionRequestListener {
errorType = "other";
}
Metrics.counter(SEND_MESSAGE_ERROR_COUNTER,
platformTag.and(ERROR_TYPE_TAG, errorType, EXCEPTION_TYPE_TAG, e.getClass().getSimpleName()))
Metrics.counter(SEND_MESSAGE_ERROR_COUNTER, Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of(ERROR_TYPE_TAG, errorType),
Tag.of(EXCEPTION_TYPE_TAG, e.getClass().getSimpleName())))
.increment();
}
@VisibleForTesting
static boolean isConnectionClosedException(final Throwable throwable) {
private static boolean isConnectionClosedException(final Throwable throwable) {
return throwable instanceof java.nio.channels.ClosedChannelException ||
throwable == WebSocketResourceProvider.CONNECTION_CLOSED_EXCEPTION ||
throwable instanceof org.eclipse.jetty.io.EofException ||
(throwable instanceof StaticException staticException && "Closed".equals(staticException.getMessage()));
}
private CompletableFuture<Void> sendMessage(Envelope envelope) {
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getStory() && !client.shouldDeliverStories()) {
messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, messageGuid, envelope.getServerTimestamp());
return CompletableFuture.completedFuture(null);
} else {
return sendMessage(envelope, new StoredMessageInfo(messageGuid, envelope.getServerTimestamp()));
}
}
@Override
public void handleNewMessageAvailable() {
Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME,
PRESENCE_MANAGER_TAG, "pubsub")
.increment();
storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE);
processStoredMessages();
}
@Override
public void handleMessagesPersisted() {
Metrics.counter(MESSAGES_PERSISTED_COUNTER_NAME,
PRESENCE_MANAGER_TAG, "pubsub")
.increment();
storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
processStoredMessages();
}
@Override
public void handleConflictingMessageConsumer() {
closeConnection(4409, "Connected elsewhere");
}
@Override
public void handleDisconnectionRequest() {
Metrics.counter(DISPLACEMENT_COUNTER_NAME, platformTag.and(CONNECTED_ELSEWHERE_TAG, "false")).increment();
client.close(4401, "Reauthentication required");
closeConnection(4401, "Reauthentication required");
}
private void closeConnection(final int code, final String message) {
final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
// TODO We should probably just use the status code directly
Tag.of("connectedElsewhere", String.valueOf(code == 4409)));
Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
client.close(code, message);
}
private record StoredMessageInfo(UUID guid, long serverTimestamp) {
}
}