Use error-specific retry mechanisms in WebSocketConnection and associated classes

This commit is contained in:
Jon Chambers
2025-07-31 10:53:11 -04:00
committed by GitHub
parent 8fc0b49994
commit 5c3be9c3d6
8 changed files with 81 additions and 124 deletions

View File

@@ -504,7 +504,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
messageDeliveryQueue);
ScheduledExecutorService recurringJobExecutor = ScheduledExecutorServiceBuilder.of(environment, "recurringJob").threads(6).build();
ScheduledExecutorService websocketScheduledExecutor = ScheduledExecutorServiceBuilder.of(environment, "websocket").threads(8).build();
ExecutorService apnSenderExecutor = ExecutorServiceBuilder.of(environment, "apnSender")
.maxThreads(1).minThreads(1).build();
ExecutorService fcmSenderExecutor = ExecutorServiceBuilder.of(environment, "fcmSender")
@@ -996,7 +995,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, redisMessageAvailabilityManager, disconnectionRequestManager, websocketScheduledExecutor,
pushNotificationScheduler, redisMessageAvailabilityManager, disconnectionRequestManager,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));

View File

@@ -21,7 +21,6 @@ import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
public class ClusterLuaScript {
@@ -133,13 +132,7 @@ public class ClusterLuaScript {
final T[] keys, final T[] args) {
return connection.reactive().evalsha(sha, scriptOutputType, keys, args)
.onErrorResume(e -> {
if (e instanceof RedisNoScriptException) {
return connection.reactive().eval(script, scriptOutputType, keys, args);
}
log.warn("Failed to execute script", e);
return Mono.error(e);
});
.onErrorResume(RedisNoScriptException.class, _ -> connection.reactive().eval(script, scriptOutputType, keys, args))
.doOnError(throwable -> log.warn("Failed to execute script", throwable));
}
}

View File

@@ -52,6 +52,7 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.RetrySpec;
/**
* Manages short-term storage of messages in Redis. Messages are frequently delivered to their destination and deleted
@@ -521,6 +522,7 @@ public class MessagesCache {
long messageId, int pageSize) {
return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId)
.retryWhen(RetrySpec.backoff(4, Duration.ofSeconds(1)).maxBackoff(Duration.ofSeconds(4)))
.map(queueItems -> {
logger.trace("Processing page: {}", messageId);

View File

@@ -9,7 +9,6 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Tags;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@@ -19,10 +18,10 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@@ -51,7 +50,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final PushNotificationScheduler pushNotificationScheduler;
private final RedisMessageAvailabilityManager redisMessageAvailabilityManager;
private final DisconnectionRequestManager disconnectionRequestManager;
private final ScheduledExecutorService scheduledExecutorService;
private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
@@ -69,7 +67,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final PushNotificationScheduler pushNotificationScheduler,
final RedisMessageAvailabilityManager redisMessageAvailabilityManager,
final DisconnectionRequestManager disconnectionRequestManager,
final ScheduledExecutorService scheduledExecutorService,
final Scheduler messageDeliveryScheduler,
final ClientReleaseManager clientReleaseManager,
final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
@@ -83,7 +80,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.pushNotificationScheduler = pushNotificationScheduler;
this.redisMessageAvailabilityManager = redisMessageAvailabilityManager;
this.disconnectionRequestManager = disconnectionRequestManager;
this.scheduledExecutorService = scheduledExecutorService;
this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
@@ -126,7 +122,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
maybeAuthenticatedAccount.get(),
maybeAuthenticatedDevice.get(),
context.getClient(),
scheduledExecutorService,
messageDeliveryScheduler,
clientReleaseManager,
messageDeliveryLoopMonitor,

View File

@@ -18,20 +18,15 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.util.StaticException;
import org.reactivestreams.Publisher;
@@ -48,10 +43,10 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageAvailabilityListener;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.MessageAvailabilityListener;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -65,6 +60,7 @@ import reactor.core.observability.micrometer.Micrometer;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.util.retry.Retry;
public class WebSocketConnection implements MessageAvailabilityListener, DisconnectionRequestListener {
@@ -80,7 +76,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
"initialQueueLength");
private static final String INITIAL_QUEUE_DRAIN_TIMER_NAME = name(WebSocketConnection.class, "drainInitialQueue");
private static final String SLOW_QUEUE_DRAIN_COUNTER_NAME = name(WebSocketConnection.class, "slowQueueDrain");
private static final String QUEUE_DRAIN_RETRY_COUNTER_NAME = name(WebSocketConnection.class, "queueDrainRetry");
private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement");
private static final String NON_SUCCESS_RESPONSE_COUNTER_NAME = name(WebSocketConnection.class,
"clientNonSuccessResponse");
@@ -105,11 +100,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
@VisibleForTesting
static final int MESSAGE_SENDER_MAX_CONCURRENCY = 256;
@VisibleForTesting
static final int MAX_CONSECUTIVE_RETRIES = 5;
private static final long RETRY_DELAY_MILLIS = 1_000;
private static final int RETRY_DELAY_JITTER_MILLIS = 500;
private static final int DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS = 5 * 60 * 1000;
private static final Duration CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY = Duration.ofMinutes(1);
@@ -130,19 +120,14 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
private final int sendFuturesTimeoutMillis;
private final ScheduledExecutorService scheduledExecutorService;
private final Semaphore processStoredMessagesSemaphore = new Semaphore(1);
private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(
StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
private final LongAdder sentMessageCounter = new LongAdder();
private final AtomicLong queueDrainStartTime = new AtomicLong();
private final AtomicInteger consecutiveRetries = new AtomicInteger();
private final AtomicReference<ScheduledFuture<?>> retryFuture = new AtomicReference<>();
private final AtomicReference<Disposable> messageSubscription = new AtomicReference<>();
private final Random random = new Random();
private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager;
@@ -161,7 +146,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
Account authenticatedAccount,
Device authenticatedDevice,
WebSocketClient client,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
@@ -176,7 +160,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
authenticatedDevice,
client,
DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS,
scheduledExecutorService,
messageDeliveryScheduler,
clientReleaseManager,
messageDeliveryLoopMonitor, experimentEnrollmentManager);
@@ -192,7 +175,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
Device authenticatedDevice,
WebSocketClient client,
int sendFuturesTimeoutMillis,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
@@ -207,7 +189,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
this.authenticatedDevice = authenticatedDevice;
this.client = client;
this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis;
this.scheduledExecutorService = scheduledExecutorService;
this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
@@ -221,12 +202,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
}
public void stop() {
final ScheduledFuture<?> future = retryFuture.get();
if (future != null) {
future.cancel(false);
}
final Disposable subscription = messageSubscription.get();
if (subscription != null) {
subscription.dispose();
@@ -342,7 +317,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
}
// Cleared the queue! Send a queue empty message if we need to
consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
@@ -362,44 +336,25 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
}
})
// Potentially kick off more work, must happen after we release the semaphore
.whenComplete((ignored, cause) -> processMoreIfRequested(cause));
.whenComplete((ignored, cause) -> {
if (cause != null) {
if (!client.isOpen()) {
logger.debug("Client disconnected before queue cleared");
return;
}
client.close(1011, "Failed to retrieve messages");
return;
}
// Success, but check if more messages came in while we were processing
if (storedMessageState.get() != StoredMessageState.EMPTY) {
processStoredMessages();
}
});
}
}
/**
 * Decides what to do once a pass over the stored-message queue has finished.
 * <ul>
 *   <li>On success, starts another pass if the stored-message state is no longer {@code EMPTY}.</li>
 *   <li>On failure with a closed client, does nothing beyond a debug log.</li>
 *   <li>On failure after more than {@code MAX_CONSECUTIVE_RETRIES} consecutive failures, closes the client with
 *       status 1011.</li>
 *   <li>On any other failure, counts the retry and schedules another pass after a base delay plus random jitter.</li>
 * </ul>
 *
 * @param cause the error raised while processing the message queue, or {@code null} if the pass succeeded
 */
private void processMoreIfRequested(final @Nullable Throwable cause) {
  final boolean succeeded = cause == null;

  if (succeeded) {
    // Nothing went wrong; more messages may still have arrived while this pass was running
    final boolean moreMessagesPending = storedMessageState.get() != StoredMessageState.EMPTY;

    if (moreMessagesPending) {
      processStoredMessages();
    }
  } else if (!client.isOpen()) {
    // No point retrying for a client that has already gone away
    logger.debug("Client disconnected before queue cleared");
  } else if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) {
    // Retry budget exhausted; give up on this connection
    logger.warn("Max consecutive retries exceeded", cause);
    client.close(1011, "Failed to retrieve messages");
  } else {
    logger.debug("Failed to clear queue", cause);

    Metrics.counter(QUEUE_DRAIN_RETRY_COUNTER_NAME,
            Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())))
        .increment();

    // Base delay plus random jitter
    final long retryDelayMillis = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS);

    // Keep a handle on the scheduled retry in retryFuture
    final ScheduledFuture<?> scheduledRetry =
        scheduledExecutorService.schedule(this::processStoredMessages, retryDelayMillis, TimeUnit.MILLISECONDS);

    retryFuture.set(scheduledRetry);
  }
}
private CompletableFuture<Void> sendMessages(final boolean cachedMessagesOnly) {
final CompletableFuture<Void> queueCleared = new CompletableFuture<>();
@@ -407,7 +362,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly);
final AtomicBoolean hasSentFirstMessage = new AtomicBoolean();
final AtomicBoolean hasErrored = new AtomicBoolean();
final Disposable subscription = Flux.from(messages)
.name(SEND_MESSAGES_FLUX_NAME)
@@ -423,19 +377,12 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
}
})
.flatMapSequential(envelope ->
Mono.fromFuture(() -> sendMessage(envelope)
.orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS))
.onErrorResume(
// let the first error pass through to terminate the subscription
e -> {
final boolean firstError = !hasErrored.getAndSet(true);
measureSendMessageErrors(e, firstError);
return !firstError;
},
// otherwise just emit nothing
e -> Mono.empty()
), MESSAGE_SENDER_MAX_CONCURRENCY)
Mono.defer(() -> Mono.fromFuture(() -> sendMessage(envelope).orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS)))
.doOnError(this::measureSendMessageErrors)
// Note that this will retry both for "send to client" timeouts and failures to delete messages on
// acknowledgement
.retryWhen(Retry.backoff(4, Duration.ofSeconds(1))),
MESSAGE_SENDER_MAX_CONCURRENCY)
.subscribeOn(messageDeliveryScheduler)
.subscribe(
// no additional consumer of values - it is Flux<Void> by now
@@ -450,7 +397,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
return queueCleared;
}
private void measureSendMessageErrors(final Throwable e, final boolean terminal) {
private void measureSendMessageErrors(final Throwable e) {
final String errorType;
if (e instanceof TimeoutException) {
@@ -461,7 +408,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn
(e instanceof StaticException staticException && "Closed".equals(staticException.getMessage()))) {
errorType = "connectionClosed";
} else {
logger.warn(terminal ? "Send message failure terminated stream" : "Send message failed", e);
logger.warn("Send message failed", e);
errorType = "other";
}