From 8fe87b77e4f030205b950d0ffb4e46bccfb488c0 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 14 Aug 2025 12:07:40 -0400 Subject: [PATCH] Wait for message acknowledgement before fetching new messags from Redis/DynamoDB --- .../RedisDynamoDbMessagePublisher.java | 127 ++++++++++++------ .../storage/RedisDynamoDbMessageStream.java | 25 +++- .../websocket/WebSocketConnection.java | 41 +----- .../RedisDynamoDbMessagePublisherTest.java | 41 ++++++ .../RedisDynamoDbMessageStreamTest.java | 6 +- .../WebSocketConnectionIntegrationTest.java | 108 ++++++++++++++- .../websocket/WebSocketConnectionTest.java | 68 +--------- 7 files changed, 256 insertions(+), 160 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java index 485c4e8cc..0a91e984f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisher.java @@ -42,9 +42,21 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow // and Redis. private StoredMessageState storedMessageState = StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE; - // The number of messages the downstream subscriber is ready to receive. This changes in response to new requests from - // the downstream subscriber and gets decremented every time we publish a message. - private long unmetDemand = 0; + // Indicates whether we've sent an "initial queue drain complete" signal to subscribers. This state will transition + // from "not ready" to "pending" as soon as the initial message source subscription completes, and then from "pending" + // to "send" when the signal has been sent. This state will never roll backwards. + private QueueEmptySignalState queueEmptySignalState = QueueEmptySignalState.NOT_READY; + + // The total requested demand from the subscriber across the whole lifetime of this publisher + private long requestedDemand = 0; + + // The total number of signals (new messages or "queue empty" signals) published during the lifetime of this + // publisher. Must not exceed `requestedDemand`. + private long publishedEntries = 0; + + // The total number of published signals that have been acknowledged by the subscriber; to avoid re-reading messages, + // we should never start a new message source subscriber until `acknowledgedEntries` is equal to `publishedEntries` + private long acknowledgedEntries = 0; // Although technically nullable, operation of this publisher really begins once we get a subscriber. This publisher // supports only a single subscriber. @@ -54,15 +66,6 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow // terminated, a publisher cannot be un-terminated. private boolean terminated = false; - // This publisher will emit exactly one "queue empty" signal once the initial contents of the message queue have been - // drained. Once emitted, this flag is set to `true` and will never change again. - private boolean publishedQueueEmptySignal = false; - - // …but we may not be able to send the "queue empty" signal downstream immediately if there's no demand. This flag - // tracks whether we're ready to publish a "queue empty" signal, regardless of whether we've actually sent it. Once - // this flag is set to `true`, it will never change again. - private boolean readyToPublishQueueEmptySignal = false; - // A message source subscriber subscribes to messages from upstream data sources (i.e. DynamoDB and Redis), and this // publisher relays signals the message source subscriber to the downstream subscriber. The message source subscriber // may be null if we're not actively fetching messages from an upstream source and changes every time an upstream @@ -83,6 +86,18 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow EMPTY } + private enum QueueEmptySignalState { + // Indicates that we are not yet ready to send the "initial queue drain complete" signal regardless of outstanding + // demand + NOT_READY, + + // Indicates that we are ready to send the "queue empty" signal as soon as demand is available + PENDING, + + // Indicates that we have sent the "queue empty" signal and must never send it again + SENT + } + /// A message source subscriber subscribes to upstream message source publishers and relays signals to the downstream /// subscriber via the parent `RedisDynamoDbMessagePublisher`. private static class MessageSourceSubscriber extends BaseSubscriber { @@ -95,12 +110,7 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow @Override protected void hookOnSubscribe(final Subscription subscription) { - final long unmetDemand = redisDynamoDbMessagePublisher.getUnmetDemand(); - - // If we already have some unmet demand, pass that on to the upstream publisher immediately on subscribing - if (unmetDemand > 0) { - subscription.request(unmetDemand); - } + redisDynamoDbMessagePublisher.handleMessageSourceSubscribed(subscription); } @Override @@ -188,10 +198,16 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow assert subscriber != null; if (!terminated) { + terminate(); subscriber.onError(new ConflictingMessageConsumerException()); } + } - terminate(); + synchronized void handleMessageAcknowledged() { + acknowledgedEntries += 1; + assert acknowledgedEntries <= publishedEntries; + + maybeGenerateMessageSource(); } private synchronized boolean maybeSendQueueEmptySignal() { @@ -203,13 +219,18 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow // The machinery that produces messages won't activate until we have a subscriber assert subscriber != null; - if (readyToPublishQueueEmptySignal && !publishedQueueEmptySignal && getUnmetDemand() > 0) { + if (queueEmptySignalState == QueueEmptySignalState.PENDING && publishedEntries < requestedDemand) { + queueEmptySignalState = QueueEmptySignalState.SENT; + publishedEntries += 1; + + // Subscribers don't explicitly acknowledge "queue empty" signals, and we can consider them automatically + // acknowledged + acknowledgedEntries += 1; + subscriber.onNext(new MessageStreamEntry.QueueEmpty()); - unmetDemand -= 1; - assert unmetDemand >= 0; - - publishedQueueEmptySignal = true; + assert publishedEntries <= requestedDemand; + assert acknowledgedEntries <= publishedEntries; return true; } @@ -218,14 +239,27 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow } private synchronized void maybeGenerateMessageSource() { - // Regardless of any other state, don't do anything if terminated if (terminated) { + // Regardless of any other state, don't do anything if terminated return; } - if (storedMessageState == StoredMessageState.EMPTY || unmetDemand == 0) { - // We don't think there are any messages in either source or there's no demand for messages; either way, wait for - // things to change before trying to generate a message source + if (storedMessageState == StoredMessageState.EMPTY) { + // We don't think there are any messages in either message source; don't do anything until the situation changes + // (when new messages arrive, we'll come back to this point with a non-empty stored message state) + return; + } + + if (publishedEntries == requestedDemand) { + // Even if there are messages available, there's no demand for them yet (when there's new demand, we'll come back + // to this point with a higher value for `requestedDemand` via `addDemand`) + return; + } + + if (acknowledgedEntries < publishedEntries) { + // To avoid double-reading messages from data stores that don't support cursors, don't get a new message source + // unless all previously-published signals have been acknowledged (when messages are acknowledged, we'll come back + // to this point with a higher value for `acknowledgedEntries` via `handleMessageAcknowledged`) return; } @@ -249,14 +283,25 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow storedMessageState = StoredMessageState.EMPTY; } + private synchronized void handleMessageSourceSubscribed(final Subscription subscription) { + if (!terminated) { + // If we already have some unmet demand, pass that on to the upstream publisher immediately on subscribing + if (requestedDemand > publishedEntries) { + subscription.request(requestedDemand - publishedEntries); + } + } + } + private synchronized void handleNextMessage(final MessageProtos.Envelope message) { // The machinery that produces messages won't activate until we have a subscriber assert subscriber != null; if (!terminated) { - unmetDemand -= 1; - assert unmetDemand >= 0; + // We only pass along unfulfilled demand to the message source subscriber, so if the message source subscriber + // emits a new signal, it should fit within the unfulfilled demand from the downstream subscriber + assert publishedEntries < requestedDemand; + publishedEntries += 1; subscriber.onNext(new MessageStreamEntry.Envelope(message)); } } @@ -267,9 +312,10 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow messageSourceSubscriber = null; - // Attempt to send a "queue empty" signal if we haven't already - readyToPublishQueueEmptySignal = true; - maybeSendQueueEmptySignal(); + if (queueEmptySignalState == QueueEmptySignalState.NOT_READY) { + queueEmptySignalState = QueueEmptySignalState.PENDING; + maybeSendQueueEmptySignal(); + } // New messages may have arrived already; fetch them if possible maybeGenerateMessageSource(); @@ -280,8 +326,8 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow assert subscriber != null; if (!terminated) { - subscriber.onError(throwable); terminate(); + subscriber.onError(throwable); } } @@ -290,9 +336,8 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow throw new IllegalArgumentException("Demand must be positive"); } - unmetDemand += demand; + requestedDemand += demand; - // We may have been waiting for non-zero demand before sending a "queue empty" signal final boolean sentQueueEmptySignal = maybeSendQueueEmptySignal(); // This is a little tricky; if we already have a subscriber, we only want to request NEW demand, not the total @@ -308,20 +353,16 @@ class RedisDynamoDbMessagePublisher implements MessageAvailabilityListener, Flow } } - private synchronized long getUnmetDemand() { - return unmetDemand; - } - private synchronized void terminate() { if (!terminated) { terminated = true; - // Stop receiving signals about new messages/conflicting consumers - redisMessageAvailabilityManager.handleClientDisconnected(accountIdentifier, device.getId()); - if (messageSourceSubscriber != null) { messageSourceSubscriber.dispose(); } + + // Stop receiving signals about new messages/conflicting consumers + redisMessageAvailabilityManager.handleClientDisconnected(accountIdentifier, device.getId()); } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStream.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStream.java index c0dbaed66..125182aa9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStream.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStream.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.storage; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Flow; +import com.google.common.annotations.VisibleForTesting; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; import org.whispersystems.textsecuregcm.util.Util; @@ -29,16 +30,25 @@ public class RedisDynamoDbMessageStream implements MessageStream { final UUID accountIdentifier, final Device device) { + this(messagesDynamoDb, messagesCache, accountIdentifier, device, new RedisDynamoDbMessagePublisher(messagesDynamoDb, + messagesCache, + redisMessageAvailabilityManager, + accountIdentifier, + device)); + } + + @VisibleForTesting + RedisDynamoDbMessageStream(final MessagesDynamoDb messagesDynamoDb, + final MessagesCache messagesCache, + final UUID accountIdentifier, + final Device device, + final RedisDynamoDbMessagePublisher messagePublisher) { + this.messagesDynamoDb = messagesDynamoDb; this.messagesCache = messagesCache; this.accountIdentifier = accountIdentifier; this.device = device; - - this.messagePublisher = new RedisDynamoDbMessagePublisher(messagesDynamoDb, - messagesCache, - redisMessageAvailabilityManager, - accountIdentifier, - device); + this.messagePublisher = messagePublisher; } @Override @@ -54,6 +64,7 @@ public class RedisDynamoDbMessageStream implements MessageStream { .thenCompose(removed -> removed.map(_ -> CompletableFuture.completedFuture(null)) .orElseGet(() -> messagesDynamoDb.deleteMessage(accountIdentifier, device, guid, message.getServerTimestamp()) - .thenRun(Util.NOOP))); + .thenRun(Util.NOOP))) + .whenComplete((_, _) -> messagePublisher.handleMessageAcknowledged()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 3665cd2c3..6abf7fd4d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -56,7 +56,6 @@ import reactor.core.Disposable; import reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; -import reactor.util.retry.Retry; public class WebSocketConnection implements DisconnectionRequestListener { @@ -103,8 +102,6 @@ public class WebSocketConnection implements DisconnectionRequestListener { private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final ExperimentEnrollmentManager experimentEnrollmentManager; - private final Retry retrySpec; - private final Account authenticatedAccount; private final Device authenticatedDevice; private final MessageStream messageStream; @@ -131,38 +128,6 @@ public class WebSocketConnection implements DisconnectionRequestListener { final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, final ExperimentEnrollmentManager experimentEnrollmentManager) { - this(receiptSender, - messagesManager, - messageMetrics, - pushNotificationManager, - pushNotificationScheduler, - authenticatedAccount, - authenticatedDevice, - client, - messageDeliveryScheduler, - clientReleaseManager, - messageDeliveryLoopMonitor, - experimentEnrollmentManager, - Retry.backoff(4, Duration.ofSeconds(1)) - .maxBackoff(Duration.ofSeconds(2)) - .filter(throwable -> !isConnectionClosedException(throwable))); - } - - @VisibleForTesting - WebSocketConnection(final ReceiptSender receiptSender, - final MessagesManager messagesManager, - final MessageMetrics messageMetrics, - final PushNotificationManager pushNotificationManager, - final PushNotificationScheduler pushNotificationScheduler, - final Account authenticatedAccount, - final Device authenticatedDevice, - final WebSocketClient client, - final Scheduler messageDeliveryScheduler, - final ClientReleaseManager clientReleaseManager, - final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, - final ExperimentEnrollmentManager experimentEnrollmentManager, - final Retry retrySpec) { - this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; @@ -176,8 +141,6 @@ public class WebSocketConnection implements DisconnectionRequestListener { this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.experimentEnrollmentManager = experimentEnrollmentManager; - this.retrySpec = retrySpec; - this.messageStream = messagesManager.getMessages(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice); @@ -214,9 +177,7 @@ public class WebSocketConnection implements DisconnectionRequestListener { } }) .flatMapSequential(entry -> switch (entry) { - case MessageStreamEntry.Envelope envelope -> Mono.fromFuture(() -> sendMessage(envelope.message())) - .retryWhen(retrySpec) - .thenReturn(entry); + case MessageStreamEntry.Envelope envelope -> Mono.fromFuture(() -> sendMessage(envelope.message())).thenReturn(entry); case MessageStreamEntry.QueueEmpty _ -> Mono.just(entry); }, MESSAGE_SENDER_MAX_CONCURRENCY) .subscribeOn(messageDeliveryScheduler) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java index 0a8a02026..b3a7627da 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessagePublisherTest.java @@ -206,7 +206,10 @@ class RedisDynamoDbMessagePublisherTest { } deleteRedisMessage(redisMessage); + messagePublisher.handleMessageAcknowledged(); + deleteDynamoDbMessage(dynamoDbMessage); + messagePublisher.handleMessageAcknowledged(); insertRedisMessage(newArrivalRedisMessage); messagePublisher.handleNewMessageAvailable(); @@ -245,7 +248,10 @@ class RedisDynamoDbMessagePublisherTest { } deleteRedisMessage(redisMessage); + messagePublisher.handleMessageAcknowledged(); + deleteDynamoDbMessage(dynamoDbMessage); + messagePublisher.handleMessageAcknowledged(); insertDynamoDbMessage(persistedMessage); messagePublisher.handleMessagesPersisted(); @@ -264,6 +270,41 @@ class RedisDynamoDbMessagePublisherTest { .verifyTimeout(Duration.ofMillis(500)); } + @Test + void publishMessagesWaitForAcknowledgement() { + final MessageProtos.Envelope dynamoDbMessage = insertDynamoDbMessage(generateRandomMessage()); + final MessageProtos.Envelope redisMessage = insertRedisMessage(generateRandomMessage()); + + final MessageProtos.Envelope persistedMessage = generateRandomMessage(); + + final RedisDynamoDbMessagePublisher messagePublisher = + new RedisDynamoDbMessagePublisher(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, DESTINATION_SERVICE_IDENTIFIER.uuid(), destinationDevice); + + final CountDownLatch queueEmptyCountDownLatch = new CountDownLatch(1); + + Thread.ofVirtual().start(() -> { + try { + queueEmptyCountDownLatch.await(); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + + insertDynamoDbMessage(persistedMessage); + messagePublisher.handleMessagesPersisted(); + }); + + StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messagePublisher) + .doOnNext(entry -> { + if (entry instanceof MessageStreamEntry.QueueEmpty) { + queueEmptyCountDownLatch.countDown(); + } + })) + .expectNext(new MessageStreamEntry.Envelope(dynamoDbMessage)) + .expectNext(new MessageStreamEntry.Envelope(redisMessage)) + .expectNext(new MessageStreamEntry.QueueEmpty()) + .verifyTimeout(Duration.ofMillis(500)); + } + @Test void publishMessagesConsumerConflict() { final RedisDynamoDbMessagePublisher messagePublisher = diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStreamTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStreamTest.java index 3bd3e802c..d19625a02 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStreamTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisDynamoDbMessageStreamTest.java @@ -20,7 +20,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; -import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; class RedisDynamoDbMessageStreamTest { @@ -44,9 +43,9 @@ class RedisDynamoDbMessageStreamTest { redisDynamoDbMessageStream = new RedisDynamoDbMessageStream(messagesDynamoDb, messagesCache, - mock(RedisMessageAvailabilityManager.class), ACCOUNT_IDENTIFIER, - device); + device, + mock(RedisDynamoDbMessagePublisher.class)); when(messagesDynamoDb.deleteMessage(any(), any(), any(), anyLong())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); @@ -74,7 +73,6 @@ class RedisDynamoDbMessageStreamTest { void acknowledgeMessageRedis() { final MessageProtos.Envelope message = generateMessage(); final UUID messageGuid = UUID.fromString(message.getServerGuid()); - final long serverTimestamp = message.getServerTimestamp(); when(messagesCache.remove(ACCOUNT_IDENTIFIER, DEVICE_ID, messageGuid)) .thenReturn(CompletableFuture.completedFuture(Optional.of(RemovedMessage.fromEnvelope(message)))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 472056928..718cf2bfb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -34,6 +34,8 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -67,6 +69,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; @@ -126,8 +129,13 @@ class WebSocketConnectionIntegrationTest { redisMessageAvailabilityManager.stop(); sharedExecutorService.shutdown(); + final Mono schedulerShutdownMono = messageDeliveryScheduler.disposeGracefully(); + //noinspection ResultOfMethodCallIgnored sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS); + schedulerShutdownMono.timeout(Duration.ofSeconds(2)) + .onErrorResume(TimeoutException.class, _ -> Mono.fromRunnable(() -> messageDeliveryScheduler.dispose())) + .block(); } @ParameterizedTest @@ -210,6 +218,101 @@ class WebSocketConnectionIntegrationTest { }); } + @Test + void testProcessStoredMessagesMultipleSegments() { + final WebSocketConnection webSocketConnection = new WebSocketConnection( + mock(ReceiptSender.class), + new MessagesManager(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, reportMessageManager, sharedExecutorService, Clock.systemUTC()), + new MessageMetrics(), + mock(PushNotificationManager.class), + mock(PushNotificationScheduler.class), + account, + device, + webSocketClient, + messageDeliveryScheduler, + clientReleaseManager, + mock(MessageDeliveryLoopMonitor.class), + mock(ExperimentEnrollmentManager.class) + ); + + final int persistedMessageCount = 77; + final int cachedMessageCount = 104; + + final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); + + assertTimeoutPreemptively(Duration.ofSeconds(15), () -> { + + { + final List persistedMessages = new ArrayList<>(persistedMessageCount); + + for (int i = 0; i < persistedMessageCount; i++) { + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); + + persistedMessages.add(envelope); + expectedMessages.add(envelope); + } + + messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device); + } + + for (int i = 0; i < cachedMessageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + + messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join(); + expectedMessages.add(envelope); + } + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + + final AtomicInteger remainingMessages = new AtomicInteger(persistedMessageCount + cachedMessageCount); + final int additionalMessageCount = 67; + + when(successResponse.getStatus()).thenReturn(200); + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())) + .thenAnswer(_ -> { + if (remainingMessages.addAndGet(-1) == 60) { + sharedExecutorService.submit(() -> { + for (int i = 0; i < additionalMessageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + + messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join(); + expectedMessages.add(envelope); + } + }); + } + + return CompletableFuture.completedFuture(successResponse); + }); + + webSocketConnection.start(); + + @SuppressWarnings("unchecked") final ArgumentCaptor> messageBodyCaptor = + ArgumentCaptor.forClass(Optional.class); + + verify(webSocketClient, timeout(10_000)) + .sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + + verify(webSocketClient, timeout(10_000).times(persistedMessageCount + cachedMessageCount + additionalMessageCount)) + .sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); + + final List sentMessages = new ArrayList<>(); + + for (final Optional maybeMessageBody : messageBodyCaptor.getAllValues()) { + maybeMessageBody.ifPresent(messageBytes -> { + try { + sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes)); + } catch (final InvalidProtocolBufferException e) { + fail("Could not parse sent message"); + } + }); + } + + assertEquals(expectedMessages, sentMessages); + }); + } + @Test void testProcessStoredMessagesClientClosed() { final WebSocketConnection webSocketConnection = new WebSocketConnection( @@ -254,8 +357,8 @@ class WebSocketConnectionIntegrationTest { expectedMessages.add(envelope); } - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn( - CompletableFuture.failedFuture(new IOException("Connection closed"))); + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())) + .thenReturn(CompletableFuture.failedFuture(new IOException("Connection closed"))); webSocketConnection.start(); @@ -293,5 +396,4 @@ class WebSocketConnectionIntegrationTest { .setDestinationServiceId(UUID.randomUUID().toString()) .build(); } - } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 94958c104..d2f962f26 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -27,7 +27,6 @@ import io.lettuce.core.RedisException; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Arrays; -import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -66,7 +65,6 @@ import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; -import reactor.util.retry.Retry; class WebSocketConnectionTest { @@ -77,11 +75,6 @@ class WebSocketConnectionTest { private Scheduler messageDeliveryScheduler; private ClientReleaseManager clientReleaseManager; - private static final int MAX_RETRIES = 3; - private static final Retry RETRY_SPEC = Retry.backoff(MAX_RETRIES, Duration.ofMillis(5)) - .maxBackoff(Duration.ofMillis(20)) - .filter(throwable -> !WebSocketConnection.isConnectionClosedException(throwable)); - private static final int SOURCE_DEVICE_ID = 1; private static final AtomicInteger ON_ERROR_DROPPED_COUNTER = new AtomicInteger(); @@ -129,8 +122,7 @@ class WebSocketConnectionTest { Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class), - RETRY_SPEC); + mock(ExperimentEnrollmentManager.class)); } @Test @@ -249,14 +241,14 @@ class WebSocketConnectionTest { verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(successfulMessage)))); - verify(client, timeout(500).times(MAX_RETRIES + 1)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(failedMessage)))); - verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + verify(client, never()).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(secondSuccessfulMessage)))); verify(messageStream).acknowledgeMessage(successfulMessage); - verify(messageStream).acknowledgeMessage(secondSuccessfulMessage); + verify(messageStream, never()).acknowledgeMessage(secondSuccessfulMessage); verify(receiptSender) .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), @@ -270,7 +262,7 @@ class WebSocketConnectionTest { AciServiceIdentifier.valueOf(failedMessage.getSourceServiceId()), failedMessage.getClientTimestamp()); - verify(receiptSender) + verify(receiptSender, never()) .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), deviceId, AciServiceIdentifier.valueOf(secondSuccessfulMessage.getSourceServiceId()), @@ -574,56 +566,6 @@ class WebSocketConnectionTest { .verify(); } - @Test - void testRetryOnError() { - final UUID accountIdentifier = UUID.randomUUID(); - - final List outgoingMessages = List.of(createMessage(accountIdentifier, accountIdentifier, 1111, "first")); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); - - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(outgoingMessages) - .map(MessageStreamEntry.Envelope::new))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) - .thenReturn(messageStream); - - when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); - - final WebSocketClient client = mock(WebSocketClient.class); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())) - .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())) - .thenReturn(CompletableFuture.completedFuture(successResponse)); - - final WebSocketConnection connection = buildWebSocketConnection(client); - - connection.start(); - - verify(client, timeout(500).times(3)) - .sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()); - - verify(messageStream, timeout(500)).acknowledgeMessage(outgoingMessages.getFirst()); - - verify(receiptSender, timeout(500)) - .sendReceipt(new AciServiceIdentifier(accountIdentifier), deviceId, new AciServiceIdentifier(accountIdentifier), 1111L); - - connection.stop(); - verify(client).close(eq(1000), anyString()); - } - private static Envelope createMessage(final UUID senderUuid, final UUID destinationUuid, final long timestamp,