diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2155d6f12..2e1495810 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -994,8 +994,9 @@ public class WhisperServerService extends Application new WebSocketConnection(receiptSender, + messagesManager, + messageMetrics, + pushNotificationManager, + pushNotificationScheduler, + account, + device, + client, + messageDeliveryScheduler, + clientReleaseManager, + messageDeliveryLoopMonitor, + experimentEnrollmentManager), + authenticated -> new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, + NEW_CONNECTION_COUNTER_NAME, + CONNECTED_DURATION_TIMER_NAME, + Duration.ofHours(3), + Tags.of(AUTHENTICATED_TAG_NAME, String.valueOf(authenticated))) + ); + } + + @VisibleForTesting AuthenticatedConnectListener( + final AccountsManager accountsManager, + final DisconnectionRequestManager disconnectionRequestManager, + final WebSocketConnectionBuilder webSocketConnectionBuilder, + final Function openWebSocketCounterBuilder) { + this.accountsManager = accountsManager; - this.receiptSender = receiptSender; - this.messagesManager = messagesManager; - this.messageMetrics = messageMetrics; - this.pushNotificationManager = pushNotificationManager; - this.pushNotificationScheduler = pushNotificationScheduler; - this.redisMessageAvailabilityManager = redisMessageAvailabilityManager; this.disconnectionRequestManager = disconnectionRequestManager; - this.messageDeliveryScheduler = messageDeliveryScheduler; - this.clientReleaseManager = clientReleaseManager; - this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; - this.experimentEnrollmentManager = experimentEnrollmentManager; + this.webSocketConnectionBuilder = webSocketConnectionBuilder; - openAuthenticatedWebSocketCounter = - new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, NEW_CONNECTION_COUNTER_NAME, CONNECTED_DURATION_TIMER_NAME, Duration.ofHours(3), Tags.of(AUTHENTICATED_TAG_NAME, "true")); - - openUnauthenticatedWebSocketCounter = - new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, NEW_CONNECTION_COUNTER_NAME, CONNECTED_DURATION_TIMER_NAME, Duration.ofHours(3), Tags.of(AUTHENTICATED_TAG_NAME, "false")); + openAuthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(true); + openUnauthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(false); } @Override public void onWebSocketConnect(final WebSocketSessionContext context) { final boolean authenticated = (context.getAuthenticated() != null); - final OpenWebSocketCounter openWebSocketCounter = - authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter; - openWebSocketCounter.countOpenWebSocket(context); + (authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter).countOpenWebSocket(context); if (authenticated) { final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); - final Optional maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.accountIdentifier()); - final Optional maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId())); + final Optional maybeAuthenticatedAccount = + accountsManager.getByAccountIdentifier(auth.accountIdentifier()); + + final Optional maybeAuthenticatedDevice = + maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId())); if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) { log.warn("{}:{} not found when opening authenticated WebSocket", auth.accountIdentifier(), auth.deviceId()); @@ -116,18 +130,10 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { return; } - final WebSocketConnection connection = new WebSocketConnection(receiptSender, - messagesManager, - messageMetrics, - pushNotificationManager, - pushNotificationScheduler, - maybeAuthenticatedAccount.get(), - maybeAuthenticatedDevice.get(), - context.getClient(), - messageDeliveryScheduler, - clientReleaseManager, - messageDeliveryLoopMonitor, - experimentEnrollmentManager); + final WebSocketConnection connection = + webSocketConnectionBuilder.buildWebSocketConnection(maybeAuthenticatedAccount.get(), + maybeAuthenticatedDevice.get(), + context.getClient()); disconnectionRequestManager.addListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI), maybeAuthenticatedDevice.get().getId(), @@ -138,27 +144,11 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { maybeAuthenticatedDevice.get().getId(), connection); - // We begin the shutdown process by removing this client's "presence," which means it will again begin to - // receive push notifications for inbound messages. We should do this first because, at this point, the - // connection has already closed and attempts to actually deliver a message via the connection will not succeed. - // It's preferable to start sending push notifications as soon as possible. - redisMessageAvailabilityManager.handleClientDisconnected(auth.accountIdentifier(), auth.deviceId()); - - // Finally, stop trying to deliver messages and send a push notification if the connection is aware of any - // undelivered messages. connection.stop(); }); try { - // Once we "start" the websocket connection, we'll cancel any scheduled "you may have new messages" push - // notifications and begin delivering any stored messages for the connected device. We have not yet declared the - // client as "present" yet. If a message arrives at this point, we will update the message availability state - // correctly, but we may also send a spurious push notification. connection.start(); - - // Finally, we register this client's presence, which suppresses push notifications. We do this last because - // receiving extra push notifications is generally preferable to missing out on a push notification. - redisMessageAvailabilityManager.handleClientConnected(auth.accountIdentifier(), auth.deviceId(), connection); } catch (final Exception e) { log.warn("Failed to initialize websocket", e); context.getClient().close(1011, "Unexpected error initializing connection"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index e4482123e..fded9d827 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -14,22 +14,17 @@ import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; import java.time.Duration; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; import org.apache.commons.lang3.StringUtils; import org.eclipse.jetty.util.StaticException; -import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; @@ -42,26 +37,28 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; -import org.whispersystems.textsecuregcm.push.MessageAvailabilityListener; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.storage.ConflictingMessageConsumerException; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.MessageStream; +import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import reactor.adapter.JdkFlowAdapter; import reactor.core.Disposable; import reactor.core.observability.micrometer.Micrometer; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.util.retry.Retry; -public class WebSocketConnection implements MessageAvailabilityListener, DisconnectionRequestListener { +public class WebSocketConnection implements DisconnectionRequestListener { private static final Counter sendMessageCounter = Metrics.counter(name(WebSocketConnection.class, "sendMessage")); private static final Counter bytesSentCounter = Metrics.counter(name(WebSocketConnection.class, "bytesSent")); @@ -78,17 +75,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn "sendMessages"); private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class, "sendMessageError"); - private static final String MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class, "messagesAvailable"); - private static final String MESSAGES_PERSISTED_COUNTER_NAME = name(WebSocketConnection.class, "messagesPersisted"); private static final String SEND_MESSAGE_DURATION_TIMER_NAME = name(WebSocketConnection.class, "sendMessageDuration"); - private static final String PRESENCE_MANAGER_TAG = "presenceManager"; private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_MESSAGE_TAG = "message"; private static final String ERROR_TYPE_TAG = "errorType"; private static final String EXCEPTION_TYPE_TAG = "exceptionType"; + private static final String CONNECTED_ELSEWHERE_TAG = "connectedElsewhere"; - private static final long SLOW_DRAIN_THRESHOLD = 10_000; + private static final Duration SLOW_DRAIN_THRESHOLD = Duration.ofSeconds(10); @VisibleForTesting static final int MESSAGE_PUBLISHER_LIMIT_RATE = 100; @@ -108,28 +103,21 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final ExperimentEnrollmentManager experimentEnrollmentManager; + private final Retry retrySpec; + private final Account authenticatedAccount; private final Device authenticatedDevice; + private final MessageStream messageStream; private final WebSocketClient client; + private final Tags platformTag; - private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); - private final AtomicReference storedMessageState = new AtomicReference<>( - StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); - private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); private final LongAdder sentMessageCounter = new LongAdder(); - private final AtomicLong queueDrainStartNanoTime = new AtomicLong(); private final AtomicReference messageSubscription = new AtomicReference<>(); private final Scheduler messageDeliveryScheduler; private final ClientReleaseManager clientReleaseManager; - private enum StoredMessageState { - EMPTY, - CACHED_NEW_MESSAGES_AVAILABLE, - PERSISTED_NEW_MESSAGES_AVAILABLE - } - public WebSocketConnection(final ReceiptSender receiptSender, final MessagesManager messagesManager, final MessageMetrics messageMetrics, @@ -143,6 +131,38 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, final ExperimentEnrollmentManager experimentEnrollmentManager) { + this(receiptSender, + messagesManager, + messageMetrics, + pushNotificationManager, + pushNotificationScheduler, + authenticatedAccount, + authenticatedDevice, + client, + messageDeliveryScheduler, + clientReleaseManager, + messageDeliveryLoopMonitor, + experimentEnrollmentManager, + Retry.backoff(4, Duration.ofSeconds(1)) + .maxBackoff(Duration.ofSeconds(2)) + .filter(throwable -> !isConnectionClosedException(throwable))); + } + + @VisibleForTesting + WebSocketConnection(final ReceiptSender receiptSender, + final MessagesManager messagesManager, + final MessageMetrics messageMetrics, + final PushNotificationManager pushNotificationManager, + final PushNotificationScheduler pushNotificationScheduler, + final Account authenticatedAccount, + final Device authenticatedDevice, + final WebSocketClient client, + final Scheduler messageDeliveryScheduler, + final ClientReleaseManager clientReleaseManager, + final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, + final ExperimentEnrollmentManager experimentEnrollmentManager, + final Retry retrySpec) { + this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; @@ -155,12 +175,79 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn this.clientReleaseManager = clientReleaseManager; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.experimentEnrollmentManager = experimentEnrollmentManager; + + this.retrySpec = retrySpec; + + this.messageStream = + messagesManager.getMessages(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice); + + this.platformTag = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); } public void start() { pushNotificationManager.handleMessagesRetrieved(authenticatedAccount, authenticatedDevice, client.getUserAgent()); - queueDrainStartNanoTime.set(System.nanoTime()); - processStoredMessages(); + + final long queueDrainStartNanos = System.nanoTime(); + final AtomicBoolean hasSentFirstMessage = new AtomicBoolean(); + + messageSubscription.set(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()) + .name(SEND_MESSAGES_FLUX_NAME) + .tap(Micrometer.metrics(Metrics.globalRegistry)) + .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) + // We want to handle conflicting connections as soon as possible, and so do this before we start processing + // messages in the `flatMapSequential` stage below. If we didn't do this first, then we'd wait for clients to + // process messages before sending the "connected elsewhere" signal, and while that's ultimately not harmful, + // it's also not ideal. + .doOnError(ConflictingMessageConsumerException.class, _ -> { + Metrics.counter(DISPLACEMENT_COUNTER_NAME, platformTag.and(CONNECTED_ELSEWHERE_TAG, "true")).increment(); + client.close(4409, "Connected elsewhere"); + }) + .doOnNext(entry -> { + if (entry instanceof MessageStreamEntry.Envelope(final Envelope message)) { + if (hasSentFirstMessage.compareAndSet(false, true)) { + messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI), + authenticatedDevice.getId(), + UUID.fromString(message.getServerGuid()), + client.getUserAgent(), + "websocket"); + } + } + }) + .flatMapSequential(entry -> switch (entry) { + case MessageStreamEntry.Envelope envelope -> Mono.fromFuture(() -> sendMessage(envelope.message())) + .retryWhen(retrySpec) + .thenReturn(entry); + case MessageStreamEntry.QueueEmpty _ -> Mono.just(entry); + }, MESSAGE_SENDER_MAX_CONCURRENCY) + // `ConflictingMessageConsumerException` is handled before processing messages + .doOnError(throwable -> !(throwable instanceof ConflictingMessageConsumerException), throwable -> { + measureSendMessageErrors(throwable); + + if (!client.isOpen()) { + logger.debug("Client disconnected before queue cleared"); + return; + } + + client.close(1011, "Failed to retrieve messages"); + }) + // Make sure we process message acknowledgements before sending the "queue clear" signal + .doOnNext(entry -> { + if (entry instanceof MessageStreamEntry.QueueEmpty) { + final Duration drainDuration = Duration.ofNanos(System.nanoTime() - queueDrainStartNanos); + + Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, platformTag).record(sentMessageCounter.sum()); + Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, platformTag).record(drainDuration); + + if (drainDuration.compareTo(SLOW_DRAIN_THRESHOLD) > 0) { + Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, platformTag).increment(); + } + + client.sendRequest("PUT", "/api/v1/queue/empty", + Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty()); + } + }) + .subscribeOn(messageDeliveryScheduler) + .subscribe()); } public void stop() { @@ -171,16 +258,22 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn client.close(1000, "OK"); - if (storedMessageState.get() != StoredMessageState.EMPTY) { - pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount, - authenticatedDevice, - CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY); - } + messagesManager.mayHaveMessages(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice) + .thenAccept(mayHaveMessages -> { + if (mayHaveMessages) { + pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount, + authenticatedDevice, + CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY); + } + }); } - private CompletableFuture sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) { - // clear ephemeral field from the envelope - final Optional body = Optional.ofNullable(message.toBuilder().clearEphemeral().build().toByteArray()); + private CompletableFuture sendMessage(final Envelope message) { + if (message.getStory() && !client.shouldDeliverStories()) { + return messageStream.acknowledgeMessage(message); + } + + final Optional body = Optional.of(serializeMessage(message)); sendMessageCounter.increment(); sentMessageCounter.increment(); @@ -208,40 +301,39 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn final CompletableFuture result; if (isSuccessResponse(response)) { - result = messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, - storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()) - .thenApply(ignored -> null); + result = messageStream.acknowledgeMessage(message); if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { sendDeliveryReceiptFor(message); } } else { - final List tags = new ArrayList<>( - List.of( - Tag.of(STATUS_CODE_TAG, String.valueOf(response.getStatus())), - UserAgentTagUtil.getPlatformTag(client.getUserAgent()) - )); + Tags tags = platformTag.and(STATUS_CODE_TAG, String.valueOf(response.getStatus())); - // TODO Remove this once we've identified the cause of message rejections from desktop clients - if (StringUtils.isNotBlank(response.getMessage())) { - tags.add(Tag.of(STATUS_MESSAGE_TAG, response.getMessage())); - } - - Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment(); - - result = CompletableFuture.completedFuture(null); + // TODO Remove this once we've identified the cause of message rejections from desktop clients + if (StringUtils.isNotBlank(response.getMessage())) { + tags = tags.and(Tag.of(STATUS_MESSAGE_TAG, response.getMessage())); } + Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment(); + + result = CompletableFuture.completedFuture(null); + } + return result; }) .thenRun(() -> sample.stop(Timer.builder(SEND_MESSAGE_DURATION_TIMER_NAME) .publishPercentileHistogram(true) .minimumExpectedValue(Duration.ofMillis(100)) .maximumExpectedValue(Duration.ofDays(1)) - .tags(Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()))) + .tags(platformTag) .register(Metrics.globalRegistry))); } + @VisibleForTesting + static byte[] serializeMessage(final Envelope message) { + return message.toBuilder().clearEphemeral().build().toByteArray(); + } + private void sendDeliveryReceiptFor(Envelope message) { if (!message.hasSourceServiceId()) { return; @@ -251,110 +343,17 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()), authenticatedDevice.getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()), message.getClientTimestamp()); - } catch (IllegalArgumentException e) { + } catch (final IllegalArgumentException e) { logger.error("Could not parse UUID: {}", message.getSourceServiceId()); - } catch (Exception e) { + } catch (final Exception e) { logger.warn("Failed to send receipt", e); } } - private boolean isSuccessResponse(WebSocketResponseMessage response) { + private static boolean isSuccessResponse(final WebSocketResponseMessage response) { return response != null && response.getStatus() >= 200 && response.getStatus() < 300; } - @VisibleForTesting - void processStoredMessages() { - if (processStoredMessagesSemaphore.tryAcquire()) { - final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); - final boolean cachedMessagesOnly = state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE; - sendMessages(cachedMessagesOnly) - // Update our state with the outcome, send the empty queue message if we need to, and release the semaphore - .whenComplete((ignored, cause) -> { - try { - if (cause != null) { - // We failed, if the state is currently EMPTY, set it to what it was before we tried - storedMessageState.compareAndSet(StoredMessageState.EMPTY, state); - return; - } - - // Cleared the queue! Send a queue empty message if we need to - if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { - final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); - final long drainDurationNanos = System.nanoTime() - queueDrainStartNanoTime.get(); - - Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); - Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDurationNanos, TimeUnit.NANOSECONDS); - - if (drainDurationNanos > SLOW_DRAIN_THRESHOLD) { - Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); - } - - client.sendRequest("PUT", "/api/v1/queue/empty", - Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty()); - } - } finally { - processStoredMessagesSemaphore.release(); - } - }) - // Potentially kick off more work, must happen after we release the semaphore - .whenComplete((ignored, cause) -> { - if (cause != null) { - if (!client.isOpen()) { - logger.debug("Client disconnected before queue cleared"); - return; - } - - client.close(1011, "Failed to retrieve messages"); - return; - } - - // Success, but check if more messages came in while we were processing - if (storedMessageState.get() != StoredMessageState.EMPTY) { - processStoredMessages(); - } - }); - } - } - - private CompletableFuture sendMessages(final boolean cachedMessagesOnly) { - final CompletableFuture queueCleared = new CompletableFuture<>(); - - final Publisher messages = - messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly); - - final AtomicBoolean hasSentFirstMessage = new AtomicBoolean(); - - final Disposable subscription = Flux.from(messages) - .name(SEND_MESSAGES_FLUX_NAME) - .tap(Micrometer.metrics(Metrics.globalRegistry)) - .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) - .doOnNext(envelope -> { - if (hasSentFirstMessage.compareAndSet(false, true)) { - messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI), - authenticatedDevice.getId(), - UUID.fromString(envelope.getServerGuid()), - client.getUserAgent(), - "websocket"); - } - }) - .flatMapSequential(envelope -> Mono.fromFuture(() -> sendMessage(envelope)) - .retryWhen(Retry.backoff(4, Duration.ofSeconds(1)).filter(throwable -> !isConnectionClosedException(throwable))), - MESSAGE_SENDER_MAX_CONCURRENCY) - .doOnError(this::measureSendMessageErrors) - .subscribeOn(messageDeliveryScheduler) - .subscribe( - // no additional consumer of values - it is Flux by now - null, - // this first error will terminate the stream, but we may get multiple errors from in-flight messages - queueCleared::completeExceptionally, - // completion - () -> queueCleared.complete(null) - ); - - messageSubscription.set(subscription); - return queueCleared; - } - private void measureSendMessageErrors(final Throwable e) { final String errorType; @@ -367,76 +366,22 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn errorType = "other"; } - Metrics.counter(SEND_MESSAGE_ERROR_COUNTER, Tags.of( - UserAgentTagUtil.getPlatformTag(client.getUserAgent()), - Tag.of(ERROR_TYPE_TAG, errorType), - Tag.of(EXCEPTION_TYPE_TAG, e.getClass().getSimpleName()))) + Metrics.counter(SEND_MESSAGE_ERROR_COUNTER, + platformTag.and(ERROR_TYPE_TAG, errorType, EXCEPTION_TYPE_TAG, e.getClass().getSimpleName())) .increment(); } - private static boolean isConnectionClosedException(final Throwable throwable) { + @VisibleForTesting + static boolean isConnectionClosedException(final Throwable throwable) { return throwable instanceof java.nio.channels.ClosedChannelException || throwable == WebSocketResourceProvider.CONNECTION_CLOSED_EXCEPTION || throwable instanceof org.eclipse.jetty.io.EofException || (throwable instanceof StaticException staticException && "Closed".equals(staticException.getMessage())); } - private CompletableFuture sendMessage(Envelope envelope) { - final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); - - if (envelope.getStory() && !client.shouldDeliverStories()) { - messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, messageGuid, envelope.getServerTimestamp()); - - return CompletableFuture.completedFuture(null); - } else { - return sendMessage(envelope, new StoredMessageInfo(messageGuid, envelope.getServerTimestamp())); - } - } - - @Override - public void handleNewMessageAvailable() { - Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME, - PRESENCE_MANAGER_TAG, "pubsub") - .increment(); - - storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); - - processStoredMessages(); - } - - @Override - public void handleMessagesPersisted() { - Metrics.counter(MESSAGES_PERSISTED_COUNTER_NAME, - PRESENCE_MANAGER_TAG, "pubsub") - .increment(); - - storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); - - processStoredMessages(); - } - - @Override - public void handleConflictingMessageConsumer() { - closeConnection(4409, "Connected elsewhere"); - } - @Override public void handleDisconnectionRequest() { - closeConnection(4401, "Reauthentication required"); - } - - private void closeConnection(final int code, final String message) { - final Tags tags = Tags.of( - UserAgentTagUtil.getPlatformTag(client.getUserAgent()), - // TODO We should probably just use the status code directly - Tag.of("connectedElsewhere", String.valueOf(code == 4409))); - - Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment(); - - client.close(code, message); - } - - private record StoredMessageInfo(UUID guid, long serverTimestamp) { - + Metrics.counter(DISPLACEMENT_COUNTER_NAME, platformTag.and(CONNECTED_ELSEWHERE_TAG, "false")).increment(); + client.close(4401, "Reauthentication required"); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java new file mode 100644 index 000000000..4a6608dc5 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.websocket; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.websocket.WebSocketClient; +import org.whispersystems.websocket.session.WebSocketSessionContext; + +class AuthenticatedConnectListenerTest { + + private AccountsManager accountsManager; + private DisconnectionRequestManager disconnectionRequestManager; + + private WebSocketConnection authenticatedWebSocketConnection; + private AuthenticatedConnectListener authenticatedConnectListener; + + private Account authenticatedAccount; + private WebSocketClient webSocketClient; + private WebSocketSessionContext webSocketSessionContext; + + private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); + private static final byte DEVICE_ID = Device.PRIMARY_ID; + + @BeforeEach + void setUpBeforeEach() { + accountsManager = mock(AccountsManager.class); + disconnectionRequestManager = mock(DisconnectionRequestManager.class); + + authenticatedWebSocketConnection = mock(WebSocketConnection.class); + + authenticatedConnectListener = new AuthenticatedConnectListener(accountsManager, + disconnectionRequestManager, + (_, _, _) -> authenticatedWebSocketConnection, + _ -> mock(OpenWebSocketCounter.class)); + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(DEVICE_ID); + + authenticatedAccount = mock(Account.class); + when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER); + when(authenticatedAccount.getDevice(DEVICE_ID)).thenReturn(Optional.of(device)); + + webSocketClient = mock(WebSocketClient.class); + + webSocketSessionContext = mock(WebSocketSessionContext.class); + when(webSocketSessionContext.getClient()).thenReturn(webSocketClient); + } + + @Test + void onWebSocketConnectAuthenticated() { + when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); + when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class)) + .thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); + + when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.of(authenticatedAccount)); + + authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); + + verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); + verify(webSocketSessionContext).addWebsocketClosedListener(any()); + verify(authenticatedWebSocketConnection).start(); + } + + @Test + void onWebSocketConnectAuthenticatedAccountNotFound() { + when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); + when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class)) + .thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); + + when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.empty()); + + authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); + + verify(webSocketClient).close(eq(1011), anyString()); + + verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any()); + verify(webSocketSessionContext, never()).addWebsocketClosedListener(any()); + verify(authenticatedWebSocketConnection, never()).start(); + } + + @Test + void onWebSocketConnectAuthenticatedStartException() { + when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); + when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class)) + .thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); + + when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.of(authenticatedAccount)); + doThrow(new RuntimeException()).when(authenticatedWebSocketConnection).start(); + + authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); + + verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); + verify(webSocketSessionContext).addWebsocketClosedListener(any()); + verify(authenticatedWebSocketConnection).start(); + + verify(webSocketClient).close(eq(1011), anyString()); + } + + @Test + void onWebSocketConnectUnauthenticated() { + authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); + + verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any()); + verify(webSocketSessionContext, never()).addWebsocketClosedListener(any()); + verify(authenticatedWebSocketConnection, never()).start(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 5853b9282..028d23b50 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -32,7 +33,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -42,7 +42,6 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.mockito.ArgumentCaptor; -import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -90,16 +89,18 @@ class WebSocketConnectionIntegrationTest { private Scheduler messageDeliveryScheduler; private ClientReleaseManager clientReleaseManager; - private DynamicConfigurationManager dynamicConfigurationManager; - private long serialTimestamp = System.currentTimeMillis(); @BeforeEach void setUp() throws Exception { sharedExecutorService = Executors.newSingleThreadExecutor(); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); - dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + + @SuppressWarnings("unchecked") final DynamicConfigurationManager dynamicConfigurationManager = + mock(DynamicConfigurationManager.class); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); + messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), @@ -115,10 +116,14 @@ class WebSocketConnectionIntegrationTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(Device.PRIMARY_ID); + + redisMessageAvailabilityManager.start(); } @AfterEach void tearDown() throws Exception { + redisMessageAvailabilityManager.stop(); + sharedExecutorService.shutdown(); //noinspection ResultOfMethodCallIgnored sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS); @@ -143,7 +148,8 @@ class WebSocketConnectionIntegrationTest { messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class)); + mock(ExperimentEnrollmentManager.class) + ); final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); @@ -171,36 +177,21 @@ class WebSocketConnectionIntegrationTest { } final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - final AtomicBoolean queueCleared = new AtomicBoolean(false); when(successResponse.getStatus()).thenReturn(200); when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())) .thenReturn(CompletableFuture.completedFuture(successResponse)); - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer( - (Answer>) invocation -> { - synchronized (queueCleared) { - queueCleared.set(true); - queueCleared.notifyAll(); - } + webSocketConnection.start(); - return CompletableFuture.completedFuture(successResponse); - }); + @SuppressWarnings("unchecked") final ArgumentCaptor> messageBodyCaptor = + ArgumentCaptor.forClass(Optional.class); - webSocketConnection.processStoredMessages(); + verify(webSocketClient, timeout(10_000)) + .sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); - synchronized (queueCleared) { - while (!queueCleared.get()) { - queueCleared.wait(); - } - } - - @SuppressWarnings("unchecked") final ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass( - Optional.class); - - verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), - eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); - verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)) + .sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); final List sentMessages = new ArrayList<>(); @@ -232,7 +223,8 @@ class WebSocketConnectionIntegrationTest { messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class)); + mock(ExperimentEnrollmentManager.class) + ); final int persistedMessageCount = 207; final int cachedMessageCount = 173; @@ -264,10 +256,10 @@ class WebSocketConnectionIntegrationTest { when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn( CompletableFuture.failedFuture(new IOException("Connection closed"))); - webSocketConnection.processStoredMessages(); + webSocketConnection.start(); //noinspection unchecked - ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); + final ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); @@ -275,7 +267,7 @@ class WebSocketConnectionIntegrationTest { eq(Optional.empty())); final List sentMessages = messageBodyCaptor.getAllValues().stream() - .map(Optional::get) + .map(Optional::orElseThrow) .map(messageBytes -> { try { return Envelope.parseFrom(messageBytes); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index f9c9d215a..9b8d44963 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -5,54 +5,39 @@ package org.whispersystems.textsecuregcm.websocket; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import com.google.common.net.HttpHeaders; import com.google.protobuf.ByteString; -import com.google.protobuf.InvalidProtocolBufferException; -import io.dropwizard.auth.basic.BasicCredentials; +import io.lettuce.core.RedisCommandTimeoutException; import io.lettuce.core.RedisException; -import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.time.Instant; -import java.util.LinkedList; +import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.Queue; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.stubbing.Answer; -import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.InOrder; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -61,47 +46,43 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; -import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.storage.ConflictingMessageConsumerException; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.MessageStream; +import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import org.whispersystems.websocket.session.WebSocketSessionContext; +import reactor.adapter.JdkFlowAdapter; import reactor.core.publisher.Flux; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; +import reactor.util.retry.Retry; class WebSocketConnectionTest { - private static final String VALID_E164 = "+14152222222"; - private static final UUID VALID_UUID = UUID.randomUUID(); - - private static final int SOURCE_DEVICE_ID = 1; - - private static final String VALID_PASSWORD = "secure"; - - private AccountAuthenticator accountAuthenticator; - private AccountsManager accountsManager; private Account account; private Device device; - private UpgradeRequest upgradeRequest; private MessagesManager messagesManager; private ReceiptSender receiptSender; private Scheduler messageDeliveryScheduler; private ClientReleaseManager clientReleaseManager; + private static final int MAX_RETRIES = 3; + private static final Retry RETRY_SPEC = Retry.backoff(MAX_RETRIES, Duration.ofMillis(5)) + .maxBackoff(Duration.ofMillis(20)) + .filter(throwable -> !WebSocketConnection.isConnectionClosedException(throwable)); + + private static final int SOURCE_DEVICE_ID = 1; + @BeforeEach void setup() { - accountAuthenticator = mock(AccountAuthenticator.class); - accountsManager = mock(AccountsManager.class); account = mock(Account.class); device = mock(Device.class); - upgradeRequest = mock(UpgradeRequest.class); messagesManager = mock(MessagesManager.class); receiptSender = mock(ReceiptSender.class); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); @@ -114,515 +95,7 @@ class WebSocketConnectionTest { messageDeliveryScheduler.dispose(); } - @Test - void testCredentials() throws Exception { - WebSocketAccountAuthenticator webSocketAuthenticator = - new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, - new MessageMetrics(Duration.ofDays(30)), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), - mock(RedisMessageAvailabilityManager.class), mock(DisconnectionRequestManager.class), - messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class)); - WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); - - when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_E164, VALID_PASSWORD)))) - .thenReturn(Optional.of(new AuthenticatedDevice(VALID_UUID, Device.PRIMARY_ID, Instant.now()))); - - Optional account = webSocketAuthenticator.authenticate(upgradeRequest); - when(sessionContext.getAuthenticated()).thenReturn(account.orElse(null)); - when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.orElse(null)); - - final WebSocketClient webSocketClient = mock(WebSocketClient.class); - when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8"); - when(sessionContext.getClient()).thenReturn(webSocketClient); - - // authenticated - valid user - connectListener.onWebSocketConnect(sessionContext); - - verify(sessionContext, times(1)).addWebsocketClosedListener( - any(WebSocketSessionContext.WebSocketEventListener.class)); - - // unauthenticated - when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); - account = webSocketAuthenticator.authenticate(upgradeRequest); - assertFalse(account.isPresent()); - - connectListener.onWebSocketConnect(sessionContext); - verify(sessionContext, times(2)).addWebsocketClosedListener( - any(WebSocketSessionContext.WebSocketEventListener.class)); - - verifyNoMoreInteractions(messagesManager); - } - - @Test - void testOpen() { - - UUID accountUuid = UUID.randomUUID(); - UUID senderOneUuid = UUID.randomUUID(); - UUID senderTwoUuid = UUID.randomUUID(); - - List outgoingMessages = List.of(createMessage(senderOneUuid, accountUuid, 1111, "first"), - createMessage(senderOneUuid, accountUuid, 2222, "second"), - createMessage(senderTwoUuid, accountUuid, 3333, "third")); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - when(account.getNumber()).thenReturn("+14152222222"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - - final Device sender1device = mock(Device.class); - - List sender1devices = List.of(sender1device); - - Account sender1 = mock(Account.class); - when(sender1.getDevices()).thenReturn(sender1devices); - - when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); - when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - - when(messagesManager.delete(any(), any(), any(), any())).thenReturn( - CompletableFuture.completedFuture(Optional.empty())); - - String userAgent = HttpHeaders.USER_AGENT; - - when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) - .thenReturn(Flux.fromIterable(outgoingMessages)); - - final List> futures = new LinkedList<>(); - final WebSocketClient client = mock(WebSocketClient.class); - - when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), nullable(List.class), any())) - .thenAnswer(invocation -> { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - }); - - WebSocketConnection connection = webSocketConnection(client); - - connection.start(); - verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), - any()); - - assertEquals(3, futures.size()); - - WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); - when(response.getStatus()).thenReturn(200); - futures.get(1).complete(response); - - futures.get(0).completeExceptionally(new IOException()); - futures.get(2).completeExceptionally(new IOException()); - - verify(messagesManager, times(1)).delete(eq(accountUuid), eq(device), - eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp())); - verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(accountUuid)), eq(deviceId), eq(new AciServiceIdentifier(senderOneUuid)), - eq(2222L)); - - connection.stop(); - verify(client).close(anyInt(), anyString()); - } - - @Test - public void testOnlineSend() { - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), argThat(d -> d.getId() == Device.PRIMARY_ID), anyBoolean())) - .thenReturn(Flux.empty()) - .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))) - .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second"))) - .thenReturn(Flux.empty()); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final AtomicInteger sendCounter = new AtomicInteger(0); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) - .thenAnswer(invocation -> { - synchronized (sendCounter) { - sendCounter.incrementAndGet(); - sendCounter.notifyAll(); - } - - return CompletableFuture.completedFuture(successResponse); - }); - - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - // This is a little hacky and non-obvious, but because the first call to getMessagesForDevice returns empty list of - // messages, the call to CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded - // future, and the whenComplete method will get called immediately on THIS thread, so we don't need to synchronize - // or wait for anything. - connection.start(); - - connection.handleNewMessageAvailable(); - - synchronized (sendCounter) { - while (sendCounter.get() < 1) { - sendCounter.wait(); - } - } - - connection.handleNewMessageAvailable(); - - synchronized (sendCounter) { - while (sendCounter.get() < 2) { - sendCounter.wait(); - } - } - }); - - verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); - } - - @Test - void testPendingSend() { - final UUID accountUuid = UUID.randomUUID(); - final UUID senderTwoUuid = UUID.randomUUID(); - - final Envelope firstMessage = Envelope.newBuilder() - .setServerGuid(UUID.randomUUID().toString()) - .setSourceServiceId(UUID.randomUUID().toString()) - .setDestinationServiceId(accountUuid.toString()) - .setUpdatedPni(UUID.randomUUID().toString()) - .setClientTimestamp(System.currentTimeMillis()) - .setSourceDevice(1) - .setType(Envelope.Type.CIPHERTEXT) - .build(); - - final Envelope secondMessage = Envelope.newBuilder() - .setServerGuid(UUID.randomUUID().toString()) - .setSourceServiceId(senderTwoUuid.toString()) - .setDestinationServiceId(accountUuid.toString()) - .setClientTimestamp(System.currentTimeMillis()) - .setSourceDevice(2) - .setType(Envelope.Type.CIPHERTEXT) - .build(); - - final List pendingMessages = List.of(firstMessage, secondMessage); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - when(account.getNumber()).thenReturn("+14152222222"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - - final Device sender1device = mock(Device.class); - - List sender1devices = List.of(sender1device); - - Account sender1 = mock(Account.class); - when(sender1.getDevices()).thenReturn(sender1devices); - - when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); - when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - - when(messagesManager.delete(any(), any(), any(), any())).thenReturn( - CompletableFuture.completedFuture(Optional.empty())); - - String userAgent = HttpHeaders.USER_AGENT; - - when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) - .thenReturn(Flux.fromIterable(pendingMessages)); - - final List> futures = new LinkedList<>(); - final WebSocketClient client = mock(WebSocketClient.class); - - when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - .thenAnswer((Answer>) invocationOnMock -> { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - }); - - WebSocketConnection connection = webSocketConnection(client); - - connection.start(); - - verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()); - - assertEquals(futures.size(), 2); - - WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); - when(response.getStatus()).thenReturn(200); - futures.get(1).complete(response); - futures.get(0).completeExceptionally(new IOException()); - - verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI))), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)), - eq(secondMessage.getClientTimestamp())); - - connection.stop(); - verify(client).close(anyInt(), anyString()); - } - - @Test - void testProcessStoredMessageConcurrency() { - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - final AtomicBoolean threadWaiting = new AtomicBoolean(false); - final AtomicBoolean returnMessageList = new AtomicBoolean(false); - - when( - messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) - .thenAnswer(invocation -> { - synchronized (threadWaiting) { - threadWaiting.set(true); - threadWaiting.notifyAll(); - } - - synchronized (returnMessageList) { - while (!returnMessageList.get()) { - returnMessageList.wait(); - } - } - - return Flux.empty(); - }); - - final Thread[] threads = new Thread[10]; - final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1); - - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - connection.processStoredMessages(); - unblockedThreadsLatch.countDown(); - }); - - threads[i].start(); - } - - unblockedThreadsLatch.await(); - - synchronized (threadWaiting) { - while (!threadWaiting.get()) { - threadWaiting.wait(); - } - } - - synchronized (returnMessageList) { - returnMessageList.set(true); - returnMessageList.notifyAll(); - } - - for (final Thread thread : threads) { - thread.join(); - } - }); - - verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), any(), eq(false)); - } - - @Test - void testProcessStoredMessagesMultiplePages() { - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - when(account.getNumber()).thenReturn("+18005551234"); - final UUID accountUuid = UUID.randomUUID(); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - final List firstPageMessages = - List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"), - createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")); - - final List secondPageMessages = - List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - - when(messagesManager.getMessagesForDeviceReactive(accountUuid, device, false)) - .thenReturn(Flux.fromStream(Stream.concat(firstPageMessages.stream(), secondPageMessages.stream()))); - - when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) - .thenReturn(CompletableFuture.completedFuture(null)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final CountDownLatch queueEmptyLatch = new CountDownLatch(1); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) - .thenAnswer(invocation -> CompletableFuture.completedFuture(successResponse)); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) - .thenAnswer(invocation -> { - queueEmptyLatch.countDown(); - return CompletableFuture.completedFuture(successResponse); - }); - - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - connection.processStoredMessages(); - queueEmptyLatch.await(); - }); - - verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), - eq("/api/v1/message"), any(List.class), any(Optional.class)); - verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - } - - @Test - void testProcessStoredMessagesMultiplePagesBackpressure() { - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - when(account.getNumber()).thenReturn("+18005551234"); - final UUID accountUuid = UUID.randomUUID(); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - // Create two publishers, each with >2x WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY messages - final TestPublisher firstPublisher = TestPublisher.createCold(); - final List firstPublisherMessages = IntStream.range(1, - 2 * WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY + 23) - .mapToObj(i -> createMessage(UUID.randomUUID(), UUID.randomUUID(), i, "content " + i)) - .toList(); - - final TestPublisher secondPublisher = TestPublisher.createCold(); - final List secondPublisherMessages = IntStream.range(firstPublisherMessages.size(), - firstPublisherMessages.size() + 2 * WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY + 73) - .mapToObj(i -> createMessage(UUID.randomUUID(), UUID.randomUUID(), i, "content " + i)) - .toList(); - - final Flux allMessages = Flux.concat(firstPublisher, secondPublisher); - when(messagesManager.getMessagesForDeviceReactive(accountUuid, device, false)) - .thenReturn(allMessages); - - when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) - .thenReturn(CompletableFuture.completedFuture(null)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final CountDownLatch queueEmptyLatch = new CountDownLatch(1); - - final Queue> pendingClientAcks = new LinkedList<>(); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) - .thenAnswer(invocation -> { - final CompletableFuture pendingAck = new CompletableFuture<>(); - pendingClientAcks.add(pendingAck); - return pendingAck; - }); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) - .thenAnswer(invocation -> { - queueEmptyLatch.countDown(); - return CompletableFuture.completedFuture(successResponse); - }); - - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - // start processing - connection.processStoredMessages(); - - firstPublisher.assertWasRequested(); - // emit all messages from the first publisher - firstPublisher.emit(firstPublisherMessages.toArray(new Envelope[]{})); - // nothing should be requested from the second publisher, because max concurrency is less than the number emitted, - // and none have completed - secondPublisher.assertWasNotRequested(); - // there should only be MESSAGE_SENDER_MAX_CONCURRENCY pending client acknowledgements - assertEquals(WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY, pendingClientAcks.size()); - - while (!pendingClientAcks.isEmpty()) { - pendingClientAcks.poll().complete(successResponse); - } - - secondPublisher.assertWasRequested(); - secondPublisher.emit(secondPublisherMessages.toArray(new Envelope[0])); - - while (!pendingClientAcks.isEmpty()) { - pendingClientAcks.poll().complete(successResponse); - } - - queueEmptyLatch.await(); - }); - - verify(client, times(firstPublisherMessages.size() + secondPublisherMessages.size())).sendRequest(eq("PUT"), - eq("/api/v1/message"), any(List.class), any(Optional.class)); - verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - } - - @Test - void testProcessStoredMessagesContainsSenderUuid() { - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - when(account.getNumber()).thenReturn("+18005551234"); - final UUID accountUuid = UUID.randomUUID(); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - final UUID senderUuid = UUID.randomUUID(); - final List messages = List.of( - createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first")); - - when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) - .thenReturn(Flux.fromIterable(messages)) - .thenReturn(Flux.empty()); - - when(messagesManager.delete(eq(accountUuid), eq(device), any(UUID.class), any())) - .thenReturn(CompletableFuture.completedFuture(null)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final CountDownLatch queueEmptyLatch = new CountDownLatch(1); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer( - invocation -> CompletableFuture.completedFuture(successResponse)); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) - .thenAnswer(invocation -> { - queueEmptyLatch.countDown(); - return CompletableFuture.completedFuture(successResponse); - }); - - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - connection.processStoredMessages(); - queueEmptyLatch.await(); - }); - - verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), - argThat(argument -> { - if (argument.isEmpty()) { - return false; - } - - final byte[] body = argument.get(); - try { - final Envelope envelope = Envelope.parseFrom(body); - if (!envelope.hasSourceServiceId() || envelope.getSourceServiceId().length() == 0) { - return false; - } - return envelope.getSourceServiceId().equals(senderUuid.toString()); - } catch (InvalidProtocolBufferException e) { - return false; - } - })); - verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - } - - private WebSocketConnection webSocketConnection(final WebSocketClient client) { + private WebSocketConnection buildWebSocketConnection(final WebSocketClient client) { return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(Duration.ofDays(30)), @@ -634,191 +107,337 @@ class WebSocketConnectionTest { Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class)); + mock(ExperimentEnrollmentManager.class), + RETRY_SPEC); } @Test - void testProcessStoredMessagesSingleEmptyCall() { + void testSendMessages() { + + final UUID destinationAccountIdentifier = UUID.randomUUID(); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); + + final byte deviceId = 2; + when(device.getId()).thenReturn(deviceId); + + final Envelope successfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Success"); + final Envelope secondSuccessfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 2, "Second success"); + + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just( + new MessageStreamEntry.Envelope(successfulMessage), + new MessageStreamEntry.QueueEmpty(), + new MessageStreamEntry.Envelope(secondSuccessfulMessage)))); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) + .thenReturn(messageStream); + + when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); + final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) - .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to - // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the - // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for - // anything. - connection.processStoredMessages(); - connection.processStoredMessages(); + when(client.isOpen()).thenReturn(true); - verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) + .thenReturn(CompletableFuture.completedFuture(successResponse)); + + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + webSocketConnection.start(); + + verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(successfulMessage)))); + + verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(secondSuccessfulMessage)))); + + verify(messageStream).acknowledgeMessage(successfulMessage); + verify(messageStream).acknowledgeMessage(secondSuccessfulMessage); + + verify(receiptSender) + .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), + deviceId, + AciServiceIdentifier.valueOf(successfulMessage.getSourceServiceId()), + successfulMessage.getClientTimestamp()); + + verify(receiptSender) + .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), + deviceId, + AciServiceIdentifier.valueOf(secondSuccessfulMessage.getSourceServiceId()), + secondSuccessfulMessage.getClientTimestamp()); + + webSocketConnection.stop(); + + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + verify(client).close(eq(1000), anyString()); } @Test - public void testRequeryOnStateMismatch() { + void testSendMessagesWithError() { + + final UUID destinationAccountIdentifier = UUID.randomUUID(); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); + + final byte deviceId = 2; + when(device.getId()).thenReturn(deviceId); + + final Envelope successfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Success"); + final Envelope failedMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 2, "Failed"); + final Envelope secondSuccessfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 3, "Second success"); + + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just( + new MessageStreamEntry.Envelope(successfulMessage), + new MessageStreamEntry.Envelope(failedMessage), + new MessageStreamEntry.QueueEmpty(), + new MessageStreamEntry.Envelope(secondSuccessfulMessage)))); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) + .thenReturn(messageStream); + + when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); + final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - final List firstPageMessages = - List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"), - createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")); - - final List secondPageMessages = - List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) - .thenReturn(Flux.fromIterable(firstPageMessages)) - .thenReturn(Flux.fromIterable(secondPageMessages)) - .thenReturn(Flux.empty()); - - when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) - .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - final CountDownLatch queueEmptyLatch = new CountDownLatch(1); + when(client.isOpen()).thenReturn(true); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) - .thenAnswer(invocation -> { - connection.handleNewMessageAvailable(); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) + .thenReturn(CompletableFuture.completedFuture(successResponse)); - return CompletableFuture.completedFuture(successResponse); - }); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), argThat(body -> + body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(failedMessage))))) + .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())); - when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) - .thenAnswer(invocation -> { - queueEmptyLatch.countDown(); - return CompletableFuture.completedFuture(successResponse); - }); + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + webSocketConnection.start(); - assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { - connection.processStoredMessages(); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(successfulMessage)))); - queueEmptyLatch.await(); - }); + verify(client, timeout(500).times(MAX_RETRIES + 1)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(failedMessage)))); - verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), - eq("/api/v1/message"), any(List.class), any(Optional.class)); - verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(secondSuccessfulMessage)))); + + verify(messageStream).acknowledgeMessage(successfulMessage); + verify(messageStream).acknowledgeMessage(secondSuccessfulMessage); + + verify(receiptSender) + .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), + deviceId, + AciServiceIdentifier.valueOf(successfulMessage.getSourceServiceId()), + successfulMessage.getClientTimestamp()); + + verify(receiptSender, never()) + .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), + deviceId, + AciServiceIdentifier.valueOf(failedMessage.getSourceServiceId()), + failedMessage.getClientTimestamp()); + + verify(receiptSender) + .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), + deviceId, + AciServiceIdentifier.valueOf(secondSuccessfulMessage.getSourceServiceId()), + secondSuccessfulMessage.getClientTimestamp()); + + verify(client, timeout(500)).close(eq(1011), anyString()); + verify(client, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); } @Test - void testProcessCachedMessagesOnly() { + void testQueueEmptySignalOrder() { + + final UUID destinationAccountIdentifier = UUID.randomUUID(); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); + + final byte deviceId = 2; + when(device.getId()).thenReturn(deviceId); + + final Envelope initialMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Initial message"); + final Envelope afterQueueDrainMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 2, "After queue drained"); + + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just( + new MessageStreamEntry.Envelope(initialMessage), + new MessageStreamEntry.QueueEmpty(), + new MessageStreamEntry.Envelope(afterQueueDrainMessage)))); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) + .thenReturn(messageStream); + + when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); + final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) - .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to - // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the - // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for - // anything. - connection.processStoredMessages(); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) + .thenAnswer(_ -> CompletableFuture.supplyAsync(() -> successResponse, + CompletableFuture.delayedExecutor(100, TimeUnit.MILLISECONDS))); - verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false); + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + webSocketConnection.start(); - connection.handleNewMessageAvailable(); + final InOrder inOrder = inOrder(client, messageStream); - verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, true); + // Sending the initial message will succeed after a delay, at which point we'll acknowledge the message. Make sure + // we wait for that process to complete before sending the "queue empty" signal + inOrder.verify(messageStream, timeout(1_000)).acknowledgeMessage(initialMessage); + inOrder.verify(client, timeout(1_000)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + + webSocketConnection.stop(); + verify(client).close(eq(1000), anyString()); } @Test - void testProcessDatabaseMessagesAfterPersist() { + void testConflictingConsumerSignalOrder() { + + final UUID destinationAccountIdentifier = UUID.randomUUID(); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); + + final byte deviceId = 2; + when(device.getId()).thenReturn(deviceId); + + final Envelope message = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Initial message"); + + final MessageStream messageStream = mock(MessageStream.class); + final TestPublisher testPublisher = TestPublisher.createCold(); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(testPublisher)); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) + .thenReturn(messageStream); + + when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); + final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) - .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to - // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the - // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for - // anything. - connection.processStoredMessages(); - connection.handleMessagesPersisted(); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) + .thenReturn(new CompletableFuture<>()); - verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false); + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + webSocketConnection.start(); + + testPublisher.next(new MessageStreamEntry.Envelope(message)); + testPublisher.error(new ConflictingMessageConsumerException()); + + final InOrder inOrder = inOrder(client, messageStream); + + // A "conflicting consumer" should close the socket as soon as possible (i.e. even if messages are still getting + // processed) + inOrder.verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> + body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(message)))); + + verify(client).close(eq(4409), anyString()); } @Test - void testRetrieveMessageException() { - UUID accountUuid = UUID.randomUUID(); + void testSendMessagesEmptyQueue() { + final UUID accountUuid = UUID.randomUUID(); + + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just(new MessageStreamEntry.QueueEmpty()))); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final WebSocketClient client = mock(WebSocketClient.class); + when(client.isOpen()).thenReturn(true); + + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + + webSocketConnection.start(); + + verify(client, timeout(1_000)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + } + + @Test + void testSendMessagesConflictingConsumer() { + final UUID accountUuid = UUID.randomUUID(); + + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.error(new ConflictingMessageConsumerException()))); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); + + final WebSocketClient client = mock(WebSocketClient.class); + when(client.isOpen()).thenReturn(true); + + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + webSocketConnection.start(); + + verify(client, timeout(1_000)).close(eq(4409), anyString()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testSendMessagesRetrievalException(final boolean clientOpen) { + final UUID accountUuid = UUID.randomUUID(); when(device.getId()).thenReturn((byte) 2); - - when(account.getNumber()).thenReturn("+14152222222"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) - .thenReturn(Flux.error(new RedisException("OH NO"))); + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.error(new RedisException("OH NO")))); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); final WebSocketClient client = mock(WebSocketClient.class); - when(client.isOpen()).thenReturn(true); + when(client.isOpen()).thenReturn(clientOpen); - WebSocketConnection connection = webSocketConnection(client); - connection.start(); + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + webSocketConnection.start(); - verify(client).close(eq(1011), anyString()); - } - - @Test - void testRetrieveMessageExceptionClientDisconnected() { - UUID accountUuid = UUID.randomUUID(); - - when(device.getId()).thenReturn((byte) 2); - - when(account.getNumber()).thenReturn("+14152222222"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - - when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) - .thenReturn(Flux.error(new RedisException("OH NO"))); - - final WebSocketClient client = mock(WebSocketClient.class); - when(client.isOpen()).thenReturn(false); - - WebSocketConnection connection = webSocketConnection(client); - connection.start(); - - verify(client, never()).close(anyInt(), anyString()); + if (clientOpen) { + verify(client).close(eq(1011), anyString()); + } else { + verify(client, never()).close(anyInt(), any()); + } } @Test @@ -828,28 +447,31 @@ class WebSocketConnectionTest { final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); - when(account.getNumber()).thenReturn("+14152222222"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final int totalMessages = 1000; - final TestPublisher testPublisher = TestPublisher.createCold(); - final Flux flux = Flux.from(testPublisher); + final TestPublisher testPublisher = TestPublisher.createCold(); + final Flux flux = Flux.from(testPublisher); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) - .thenReturn(flux); + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(flux)); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(true); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(messagesManager.delete(any(), any(), any(), any())).thenReturn( - CompletableFuture.completedFuture(Optional.empty())); - WebSocketConnection connection = webSocketConnection(client); + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - connection.start(); + webSocketConnection.start(); StepVerifier.setDefaultTimeout(Duration.ofSeconds(5)); @@ -858,7 +480,7 @@ class WebSocketConnectionTest { .thenRequest(totalMessages * 2) .then(() -> { for (long i = 0; i < totalMessages; i++) { - testPublisher.next(createMessage(UUID.randomUUID(), accountUuid, 1111 * i + 1, "message " + i)); + testPublisher.next(new MessageStreamEntry.Envelope(createMessage(UUID.randomUUID(), accountUuid, 1111 * i + 1, "message " + i))); } testPublisher.complete(); }) @@ -877,40 +499,46 @@ class WebSocketConnectionTest { final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); - when(account.getNumber()).thenReturn("+14152222222"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final AtomicBoolean canceled = new AtomicBoolean(); - final Flux flux = Flux.create(s -> { + final Flux flux = Flux.create(s -> { s.onRequest(n -> { // the subscriber should request more than 1 message, but we will only send one, so that // we are sure the subscriber is waiting for more when we stop the connection assert n > 1; - s.next(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")); + s.next(new MessageStreamEntry.Envelope(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))); }); s.onCancel(() -> canceled.set(true)); }); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) - .thenReturn(flux); + + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(flux)); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); + when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(true); - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(messagesManager.delete(any(), any(), any(), any())).thenReturn( - CompletableFuture.completedFuture(Optional.empty())); - WebSocketConnection connection = webSocketConnection(client); + final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - connection.start(); + webSocketConnection.start(); verify(client).sendRequest(any(), any(), any(), any()); // close the connection before the publisher completes - connection.stop(); + webSocketConnection.stop(); StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); @@ -924,7 +552,61 @@ class WebSocketConnectionTest { .verify(); } - private Envelope createMessage(UUID senderUuid, UUID destinationUuid, long timestamp, String content) { + @Test + void testRetryOnError() { + final UUID accountIdentifier = UUID.randomUUID(); + + final List outgoingMessages = List.of(createMessage(accountIdentifier, accountIdentifier, 1111, "first")); + + final byte deviceId = 2; + when(device.getId()).thenReturn(deviceId); + + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + + final MessageStream messageStream = mock(MessageStream.class); + + when(messageStream.getMessages()) + .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(outgoingMessages) + .map(MessageStreamEntry.Envelope::new))); + + when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) + .thenReturn(messageStream); + + when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); + + final WebSocketClient client = mock(WebSocketClient.class); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) + .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())) + .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())) + .thenReturn(CompletableFuture.completedFuture(successResponse)); + + final WebSocketConnection connection = buildWebSocketConnection(client); + + connection.start(); + + verify(client, timeout(500).times(3)) + .sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()); + + verify(messageStream, timeout(500)).acknowledgeMessage(outgoingMessages.getFirst()); + + verify(receiptSender, timeout(500)) + .sendReceipt(new AciServiceIdentifier(accountIdentifier), deviceId, new AciServiceIdentifier(accountIdentifier), 1111L); + + connection.stop(); + verify(client).close(eq(1000), anyString()); + } + + private static Envelope createMessage(final UUID senderUuid, + final UUID destinationUuid, + final long timestamp, + final String content) { + return Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) .setType(Envelope.Type.CIPHERTEXT)