diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2e1495810..2155d6f12 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -994,9 +994,8 @@ public class WhisperServerService extends Application new WebSocketConnection(receiptSender, - messagesManager, - messageMetrics, - pushNotificationManager, - pushNotificationScheduler, - account, - device, - client, - messageDeliveryScheduler, - clientReleaseManager, - messageDeliveryLoopMonitor, - experimentEnrollmentManager), - authenticated -> new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, - NEW_CONNECTION_COUNTER_NAME, - CONNECTED_DURATION_TIMER_NAME, - Duration.ofHours(3), - Tags.of(AUTHENTICATED_TAG_NAME, String.valueOf(authenticated))) - ); - } - - @VisibleForTesting AuthenticatedConnectListener( - final AccountsManager accountsManager, - final DisconnectionRequestManager disconnectionRequestManager, - final WebSocketConnectionBuilder webSocketConnectionBuilder, - final Function openWebSocketCounterBuilder) { - this.accountsManager = accountsManager; + this.receiptSender = receiptSender; + this.messagesManager = messagesManager; + this.messageMetrics = messageMetrics; + this.pushNotificationManager = pushNotificationManager; + this.pushNotificationScheduler = pushNotificationScheduler; + this.redisMessageAvailabilityManager = redisMessageAvailabilityManager; this.disconnectionRequestManager = disconnectionRequestManager; - this.webSocketConnectionBuilder = webSocketConnectionBuilder; + this.messageDeliveryScheduler = messageDeliveryScheduler; + this.clientReleaseManager = clientReleaseManager; + this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; + this.experimentEnrollmentManager = experimentEnrollmentManager; - openAuthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(true); - openUnauthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(false); + openAuthenticatedWebSocketCounter = + new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, NEW_CONNECTION_COUNTER_NAME, CONNECTED_DURATION_TIMER_NAME, Duration.ofHours(3), Tags.of(AUTHENTICATED_TAG_NAME, "true")); + + openUnauthenticatedWebSocketCounter = + new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, NEW_CONNECTION_COUNTER_NAME, CONNECTED_DURATION_TIMER_NAME, Duration.ofHours(3), Tags.of(AUTHENTICATED_TAG_NAME, "false")); } @Override public void onWebSocketConnect(final WebSocketSessionContext context) { final boolean authenticated = (context.getAuthenticated() != null); + final OpenWebSocketCounter openWebSocketCounter = + authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter; - (authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter).countOpenWebSocket(context); + openWebSocketCounter.countOpenWebSocket(context); if (authenticated) { final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); - final Optional maybeAuthenticatedAccount = - accountsManager.getByAccountIdentifier(auth.accountIdentifier()); - - final Optional maybeAuthenticatedDevice = - maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId())); + final Optional maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.accountIdentifier()); + final Optional maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.deviceId())); if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) { log.warn("{}:{} not found when opening authenticated WebSocket", auth.accountIdentifier(), auth.deviceId()); @@ -130,10 +116,18 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { return; } - final WebSocketConnection connection = - webSocketConnectionBuilder.buildWebSocketConnection(maybeAuthenticatedAccount.get(), - maybeAuthenticatedDevice.get(), - context.getClient()); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, + messagesManager, + messageMetrics, + pushNotificationManager, + pushNotificationScheduler, + maybeAuthenticatedAccount.get(), + maybeAuthenticatedDevice.get(), + context.getClient(), + messageDeliveryScheduler, + clientReleaseManager, + messageDeliveryLoopMonitor, + experimentEnrollmentManager); disconnectionRequestManager.addListener(maybeAuthenticatedAccount.get().getIdentifier(IdentityType.ACI), maybeAuthenticatedDevice.get().getId(), @@ -144,11 +138,27 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { maybeAuthenticatedDevice.get().getId(), connection); + // We begin the shutdown process by removing this client's "presence," which means it will again begin to + // receive push notifications for inbound messages. We should do this first because, at this point, the + // connection has already closed and attempts to actually deliver a message via the connection will not succeed. + // It's preferable to start sending push notifications as soon as possible. + redisMessageAvailabilityManager.handleClientDisconnected(auth.accountIdentifier(), auth.deviceId()); + + // Finally, stop trying to deliver messages and send a push notification if the connection is aware of any + // undelivered messages. connection.stop(); }); try { + // Once we "start" the websocket connection, we'll cancel any scheduled "you may have new messages" push + // notifications and begin delivering any stored messages for the connected device. We have not yet declared the + // client as "present" yet. If a message arrives at this point, we will update the message availability state + // correctly, but we may also send a spurious push notification. connection.start(); + + // Finally, we register this client's presence, which suppresses push notifications. We do this last because + // receiving extra push notifications is generally preferable to missing out on a push notification. + redisMessageAvailabilityManager.handleClientConnected(auth.accountIdentifier(), auth.deviceId(), connection); } catch (final Exception e) { log.warn("Failed to initialize websocket", e); context.getClient().close(1011, "Unexpected error initializing connection"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index fded9d827..e4482123e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -14,17 +14,22 @@ import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; import java.time.Duration; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; import org.apache.commons.lang3.StringUtils; import org.eclipse.jetty.util.StaticException; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestListener; @@ -37,28 +42,26 @@ import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.push.MessageAvailabilityListener; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; -import org.whispersystems.textsecuregcm.storage.ConflictingMessageConsumerException; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.MessageStream; -import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import reactor.adapter.JdkFlowAdapter; import reactor.core.Disposable; import reactor.core.observability.micrometer.Micrometer; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.util.retry.Retry; -public class WebSocketConnection implements DisconnectionRequestListener { +public class WebSocketConnection implements MessageAvailabilityListener, DisconnectionRequestListener { private static final Counter sendMessageCounter = Metrics.counter(name(WebSocketConnection.class, "sendMessage")); private static final Counter bytesSentCounter = Metrics.counter(name(WebSocketConnection.class, "bytesSent")); @@ -75,15 +78,17 @@ public class WebSocketConnection implements DisconnectionRequestListener { "sendMessages"); private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class, "sendMessageError"); + private static final String MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class, "messagesAvailable"); + private static final String MESSAGES_PERSISTED_COUNTER_NAME = name(WebSocketConnection.class, "messagesPersisted"); private static final String SEND_MESSAGE_DURATION_TIMER_NAME = name(WebSocketConnection.class, "sendMessageDuration"); + private static final String PRESENCE_MANAGER_TAG = "presenceManager"; private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_MESSAGE_TAG = "message"; private static final String ERROR_TYPE_TAG = "errorType"; private static final String EXCEPTION_TYPE_TAG = "exceptionType"; - private static final String CONNECTED_ELSEWHERE_TAG = "connectedElsewhere"; - private static final Duration SLOW_DRAIN_THRESHOLD = Duration.ofSeconds(10); + private static final long SLOW_DRAIN_THRESHOLD = 10_000; @VisibleForTesting static final int MESSAGE_PUBLISHER_LIMIT_RATE = 100; @@ -103,21 +108,28 @@ public class WebSocketConnection implements DisconnectionRequestListener { private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final ExperimentEnrollmentManager experimentEnrollmentManager; - private final Retry retrySpec; - private final Account authenticatedAccount; private final Device authenticatedDevice; - private final MessageStream messageStream; private final WebSocketClient client; - private final Tags platformTag; + private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); + private final AtomicReference storedMessageState = new AtomicReference<>( + StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); + private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); private final LongAdder sentMessageCounter = new LongAdder(); + private final AtomicLong queueDrainStartNanoTime = new AtomicLong(); private final AtomicReference messageSubscription = new AtomicReference<>(); private final Scheduler messageDeliveryScheduler; private final ClientReleaseManager clientReleaseManager; + private enum StoredMessageState { + EMPTY, + CACHED_NEW_MESSAGES_AVAILABLE, + PERSISTED_NEW_MESSAGES_AVAILABLE + } + public WebSocketConnection(final ReceiptSender receiptSender, final MessagesManager messagesManager, final MessageMetrics messageMetrics, @@ -131,38 +143,6 @@ public class WebSocketConnection implements DisconnectionRequestListener { final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, final ExperimentEnrollmentManager experimentEnrollmentManager) { - this(receiptSender, - messagesManager, - messageMetrics, - pushNotificationManager, - pushNotificationScheduler, - authenticatedAccount, - authenticatedDevice, - client, - messageDeliveryScheduler, - clientReleaseManager, - messageDeliveryLoopMonitor, - experimentEnrollmentManager, - Retry.backoff(4, Duration.ofSeconds(1)) - .maxBackoff(Duration.ofSeconds(2)) - .filter(throwable -> !isConnectionClosedException(throwable))); - } - - @VisibleForTesting - WebSocketConnection(final ReceiptSender receiptSender, - final MessagesManager messagesManager, - final MessageMetrics messageMetrics, - final PushNotificationManager pushNotificationManager, - final PushNotificationScheduler pushNotificationScheduler, - final Account authenticatedAccount, - final Device authenticatedDevice, - final WebSocketClient client, - final Scheduler messageDeliveryScheduler, - final ClientReleaseManager clientReleaseManager, - final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, - final ExperimentEnrollmentManager experimentEnrollmentManager, - final Retry retrySpec) { - this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; @@ -175,79 +155,12 @@ public class WebSocketConnection implements DisconnectionRequestListener { this.clientReleaseManager = clientReleaseManager; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.experimentEnrollmentManager = experimentEnrollmentManager; - - this.retrySpec = retrySpec; - - this.messageStream = - messagesManager.getMessages(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice); - - this.platformTag = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); } public void start() { pushNotificationManager.handleMessagesRetrieved(authenticatedAccount, authenticatedDevice, client.getUserAgent()); - - final long queueDrainStartNanos = System.nanoTime(); - final AtomicBoolean hasSentFirstMessage = new AtomicBoolean(); - - messageSubscription.set(JdkFlowAdapter.flowPublisherToFlux(messageStream.getMessages()) - .name(SEND_MESSAGES_FLUX_NAME) - .tap(Micrometer.metrics(Metrics.globalRegistry)) - .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) - // We want to handle conflicting connections as soon as possible, and so do this before we start processing - // messages in the `flatMapSequential` stage below. If we didn't do this first, then we'd wait for clients to - // process messages before sending the "connected elsewhere" signal, and while that's ultimately not harmful, - // it's also not ideal. - .doOnError(ConflictingMessageConsumerException.class, _ -> { - Metrics.counter(DISPLACEMENT_COUNTER_NAME, platformTag.and(CONNECTED_ELSEWHERE_TAG, "true")).increment(); - client.close(4409, "Connected elsewhere"); - }) - .doOnNext(entry -> { - if (entry instanceof MessageStreamEntry.Envelope(final Envelope message)) { - if (hasSentFirstMessage.compareAndSet(false, true)) { - messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI), - authenticatedDevice.getId(), - UUID.fromString(message.getServerGuid()), - client.getUserAgent(), - "websocket"); - } - } - }) - .flatMapSequential(entry -> switch (entry) { - case MessageStreamEntry.Envelope envelope -> Mono.fromFuture(() -> sendMessage(envelope.message())) - .retryWhen(retrySpec) - .thenReturn(entry); - case MessageStreamEntry.QueueEmpty _ -> Mono.just(entry); - }, MESSAGE_SENDER_MAX_CONCURRENCY) - // `ConflictingMessageConsumerException` is handled before processing messages - .doOnError(throwable -> !(throwable instanceof ConflictingMessageConsumerException), throwable -> { - measureSendMessageErrors(throwable); - - if (!client.isOpen()) { - logger.debug("Client disconnected before queue cleared"); - return; - } - - client.close(1011, "Failed to retrieve messages"); - }) - // Make sure we process message acknowledgements before sending the "queue clear" signal - .doOnNext(entry -> { - if (entry instanceof MessageStreamEntry.QueueEmpty) { - final Duration drainDuration = Duration.ofNanos(System.nanoTime() - queueDrainStartNanos); - - Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, platformTag).record(sentMessageCounter.sum()); - Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, platformTag).record(drainDuration); - - if (drainDuration.compareTo(SLOW_DRAIN_THRESHOLD) > 0) { - Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, platformTag).increment(); - } - - client.sendRequest("PUT", "/api/v1/queue/empty", - Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty()); - } - }) - .subscribeOn(messageDeliveryScheduler) - .subscribe()); + queueDrainStartNanoTime.set(System.nanoTime()); + processStoredMessages(); } public void stop() { @@ -258,22 +171,16 @@ public class WebSocketConnection implements DisconnectionRequestListener { client.close(1000, "OK"); - messagesManager.mayHaveMessages(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice) - .thenAccept(mayHaveMessages -> { - if (mayHaveMessages) { - pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount, - authenticatedDevice, - CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY); - } - }); + if (storedMessageState.get() != StoredMessageState.EMPTY) { + pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount, + authenticatedDevice, + CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY); + } } - private CompletableFuture sendMessage(final Envelope message) { - if (message.getStory() && !client.shouldDeliverStories()) { - return messageStream.acknowledgeMessage(message); - } - - final Optional body = Optional.of(serializeMessage(message)); + private CompletableFuture sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) { + // clear ephemeral field from the envelope + final Optional body = Optional.ofNullable(message.toBuilder().clearEphemeral().build().toByteArray()); sendMessageCounter.increment(); sentMessageCounter.increment(); @@ -301,39 +208,40 @@ public class WebSocketConnection implements DisconnectionRequestListener { final CompletableFuture result; if (isSuccessResponse(response)) { - result = messageStream.acknowledgeMessage(message); + result = messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, + storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()) + .thenApply(ignored -> null); if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { sendDeliveryReceiptFor(message); } } else { - Tags tags = platformTag.and(STATUS_CODE_TAG, String.valueOf(response.getStatus())); + final List tags = new ArrayList<>( + List.of( + Tag.of(STATUS_CODE_TAG, String.valueOf(response.getStatus())), + UserAgentTagUtil.getPlatformTag(client.getUserAgent()) + )); - // TODO Remove this once we've identified the cause of message rejections from desktop clients - if (StringUtils.isNotBlank(response.getMessage())) { - tags = tags.and(Tag.of(STATUS_MESSAGE_TAG, response.getMessage())); + // TODO Remove this once we've identified the cause of message rejections from desktop clients + if (StringUtils.isNotBlank(response.getMessage())) { + tags.add(Tag.of(STATUS_MESSAGE_TAG, response.getMessage())); + } + + Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment(); + + result = CompletableFuture.completedFuture(null); } - Metrics.counter(NON_SUCCESS_RESPONSE_COUNTER_NAME, tags).increment(); - - result = CompletableFuture.completedFuture(null); - } - return result; }) .thenRun(() -> sample.stop(Timer.builder(SEND_MESSAGE_DURATION_TIMER_NAME) .publishPercentileHistogram(true) .minimumExpectedValue(Duration.ofMillis(100)) .maximumExpectedValue(Duration.ofDays(1)) - .tags(platformTag) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()))) .register(Metrics.globalRegistry))); } - @VisibleForTesting - static byte[] serializeMessage(final Envelope message) { - return message.toBuilder().clearEphemeral().build().toByteArray(); - } - private void sendDeliveryReceiptFor(Envelope message) { if (!message.hasSourceServiceId()) { return; @@ -343,17 +251,110 @@ public class WebSocketConnection implements DisconnectionRequestListener { receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()), authenticatedDevice.getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()), message.getClientTimestamp()); - } catch (final IllegalArgumentException e) { + } catch (IllegalArgumentException e) { logger.error("Could not parse UUID: {}", message.getSourceServiceId()); - } catch (final Exception e) { + } catch (Exception e) { logger.warn("Failed to send receipt", e); } } - private static boolean isSuccessResponse(final WebSocketResponseMessage response) { + private boolean isSuccessResponse(WebSocketResponseMessage response) { return response != null && response.getStatus() >= 200 && response.getStatus() < 300; } + @VisibleForTesting + void processStoredMessages() { + if (processStoredMessagesSemaphore.tryAcquire()) { + final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); + final boolean cachedMessagesOnly = state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE; + sendMessages(cachedMessagesOnly) + // Update our state with the outcome, send the empty queue message if we need to, and release the semaphore + .whenComplete((ignored, cause) -> { + try { + if (cause != null) { + // We failed, if the state is currently EMPTY, set it to what it was before we tried + storedMessageState.compareAndSet(StoredMessageState.EMPTY, state); + return; + } + + // Cleared the queue! Send a queue empty message if we need to + if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { + final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); + final long drainDurationNanos = System.nanoTime() - queueDrainStartNanoTime.get(); + + Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); + Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDurationNanos, TimeUnit.NANOSECONDS); + + if (drainDurationNanos > SLOW_DRAIN_THRESHOLD) { + Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); + } + + client.sendRequest("PUT", "/api/v1/queue/empty", + Collections.singletonList(HeaderUtils.getTimestampHeader()), Optional.empty()); + } + } finally { + processStoredMessagesSemaphore.release(); + } + }) + // Potentially kick off more work, must happen after we release the semaphore + .whenComplete((ignored, cause) -> { + if (cause != null) { + if (!client.isOpen()) { + logger.debug("Client disconnected before queue cleared"); + return; + } + + client.close(1011, "Failed to retrieve messages"); + return; + } + + // Success, but check if more messages came in while we were processing + if (storedMessageState.get() != StoredMessageState.EMPTY) { + processStoredMessages(); + } + }); + } + } + + private CompletableFuture sendMessages(final boolean cachedMessagesOnly) { + final CompletableFuture queueCleared = new CompletableFuture<>(); + + final Publisher messages = + messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly); + + final AtomicBoolean hasSentFirstMessage = new AtomicBoolean(); + + final Disposable subscription = Flux.from(messages) + .name(SEND_MESSAGES_FLUX_NAME) + .tap(Micrometer.metrics(Metrics.globalRegistry)) + .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) + .doOnNext(envelope -> { + if (hasSentFirstMessage.compareAndSet(false, true)) { + messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI), + authenticatedDevice.getId(), + UUID.fromString(envelope.getServerGuid()), + client.getUserAgent(), + "websocket"); + } + }) + .flatMapSequential(envelope -> Mono.fromFuture(() -> sendMessage(envelope)) + .retryWhen(Retry.backoff(4, Duration.ofSeconds(1)).filter(throwable -> !isConnectionClosedException(throwable))), + MESSAGE_SENDER_MAX_CONCURRENCY) + .doOnError(this::measureSendMessageErrors) + .subscribeOn(messageDeliveryScheduler) + .subscribe( + // no additional consumer of values - it is Flux by now + null, + // this first error will terminate the stream, but we may get multiple errors from in-flight messages + queueCleared::completeExceptionally, + // completion + () -> queueCleared.complete(null) + ); + + messageSubscription.set(subscription); + return queueCleared; + } + private void measureSendMessageErrors(final Throwable e) { final String errorType; @@ -366,22 +367,76 @@ public class WebSocketConnection implements DisconnectionRequestListener { errorType = "other"; } - Metrics.counter(SEND_MESSAGE_ERROR_COUNTER, - platformTag.and(ERROR_TYPE_TAG, errorType, EXCEPTION_TYPE_TAG, e.getClass().getSimpleName())) + Metrics.counter(SEND_MESSAGE_ERROR_COUNTER, Tags.of( + UserAgentTagUtil.getPlatformTag(client.getUserAgent()), + Tag.of(ERROR_TYPE_TAG, errorType), + Tag.of(EXCEPTION_TYPE_TAG, e.getClass().getSimpleName()))) .increment(); } - @VisibleForTesting - static boolean isConnectionClosedException(final Throwable throwable) { + private static boolean isConnectionClosedException(final Throwable throwable) { return throwable instanceof java.nio.channels.ClosedChannelException || throwable == WebSocketResourceProvider.CONNECTION_CLOSED_EXCEPTION || throwable instanceof org.eclipse.jetty.io.EofException || (throwable instanceof StaticException staticException && "Closed".equals(staticException.getMessage())); } + private CompletableFuture sendMessage(Envelope envelope) { + final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); + + if (envelope.getStory() && !client.shouldDeliverStories()) { + messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, messageGuid, envelope.getServerTimestamp()); + + return CompletableFuture.completedFuture(null); + } else { + return sendMessage(envelope, new StoredMessageInfo(messageGuid, envelope.getServerTimestamp())); + } + } + + @Override + public void handleNewMessageAvailable() { + Metrics.counter(MESSAGE_AVAILABLE_COUNTER_NAME, + PRESENCE_MANAGER_TAG, "pubsub") + .increment(); + + storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); + + processStoredMessages(); + } + + @Override + public void handleMessagesPersisted() { + Metrics.counter(MESSAGES_PERSISTED_COUNTER_NAME, + PRESENCE_MANAGER_TAG, "pubsub") + .increment(); + + storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); + + processStoredMessages(); + } + + @Override + public void handleConflictingMessageConsumer() { + closeConnection(4409, "Connected elsewhere"); + } + @Override public void handleDisconnectionRequest() { - Metrics.counter(DISPLACEMENT_COUNTER_NAME, platformTag.and(CONNECTED_ELSEWHERE_TAG, "false")).increment(); - client.close(4401, "Reauthentication required"); + closeConnection(4401, "Reauthentication required"); + } + + private void closeConnection(final int code, final String message) { + final Tags tags = Tags.of( + UserAgentTagUtil.getPlatformTag(client.getUserAgent()), + // TODO We should probably just use the status code directly + Tag.of("connectedElsewhere", String.valueOf(code == 4409))); + + Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment(); + + client.close(code, message); + } + + private record StoredMessageInfo(UUID guid, long serverTimestamp) { + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java deleted file mode 100644 index 4a6608dc5..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.websocket; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.time.Instant; -import java.util.Optional; -import java.util.UUID; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; -import org.whispersystems.textsecuregcm.identity.IdentityType; -import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.websocket.WebSocketClient; -import org.whispersystems.websocket.session.WebSocketSessionContext; - -class AuthenticatedConnectListenerTest { - - private AccountsManager accountsManager; - private DisconnectionRequestManager disconnectionRequestManager; - - private WebSocketConnection authenticatedWebSocketConnection; - private AuthenticatedConnectListener authenticatedConnectListener; - - private Account authenticatedAccount; - private WebSocketClient webSocketClient; - private WebSocketSessionContext webSocketSessionContext; - - private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); - private static final byte DEVICE_ID = Device.PRIMARY_ID; - - @BeforeEach - void setUpBeforeEach() { - accountsManager = mock(AccountsManager.class); - disconnectionRequestManager = mock(DisconnectionRequestManager.class); - - authenticatedWebSocketConnection = mock(WebSocketConnection.class); - - authenticatedConnectListener = new AuthenticatedConnectListener(accountsManager, - disconnectionRequestManager, - (_, _, _) -> authenticatedWebSocketConnection, - _ -> mock(OpenWebSocketCounter.class)); - - final Device device = mock(Device.class); - when(device.getId()).thenReturn(DEVICE_ID); - - authenticatedAccount = mock(Account.class); - when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTIFIER); - when(authenticatedAccount.getDevice(DEVICE_ID)).thenReturn(Optional.of(device)); - - webSocketClient = mock(WebSocketClient.class); - - webSocketSessionContext = mock(WebSocketSessionContext.class); - when(webSocketSessionContext.getClient()).thenReturn(webSocketClient); - } - - @Test - void onWebSocketConnectAuthenticated() { - when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); - when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class)) - .thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); - - when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.of(authenticatedAccount)); - - authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); - - verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); - verify(webSocketSessionContext).addWebsocketClosedListener(any()); - verify(authenticatedWebSocketConnection).start(); - } - - @Test - void onWebSocketConnectAuthenticatedAccountNotFound() { - when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); - when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class)) - .thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); - - when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.empty()); - - authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); - - verify(webSocketClient).close(eq(1011), anyString()); - - verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any()); - verify(webSocketSessionContext, never()).addWebsocketClosedListener(any()); - verify(authenticatedWebSocketConnection, never()).start(); - } - - @Test - void onWebSocketConnectAuthenticatedStartException() { - when(webSocketSessionContext.getAuthenticated()).thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); - when(webSocketSessionContext.getAuthenticated(AuthenticatedDevice.class)) - .thenReturn(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID, Instant.now())); - - when(accountsManager.getByAccountIdentifier(ACCOUNT_IDENTIFIER)).thenReturn(Optional.of(authenticatedAccount)); - doThrow(new RuntimeException()).when(authenticatedWebSocketConnection).start(); - - authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); - - verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); - verify(webSocketSessionContext).addWebsocketClosedListener(any()); - verify(authenticatedWebSocketConnection).start(); - - verify(webSocketClient).close(eq(1011), anyString()); - } - - @Test - void onWebSocketConnectUnauthenticated() { - authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); - - verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any()); - verify(webSocketSessionContext, never()).addWebsocketClosedListener(any()); - verify(authenticatedWebSocketConnection, never()).start(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 028d23b50..5853b9282 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -15,7 +15,6 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -33,6 +32,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -42,6 +42,7 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -89,18 +90,16 @@ class WebSocketConnectionIntegrationTest { private Scheduler messageDeliveryScheduler; private ClientReleaseManager clientReleaseManager; + private DynamicConfigurationManager dynamicConfigurationManager; + private long serialTimestamp = System.currentTimeMillis(); @BeforeEach void setUp() throws Exception { sharedExecutorService = Executors.newSingleThreadExecutor(); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); - - @SuppressWarnings("unchecked") final DynamicConfigurationManager dynamicConfigurationManager = - mock(DynamicConfigurationManager.class); - + dynamicConfigurationManager = mock(DynamicConfigurationManager.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); - messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messageDeliveryScheduler, sharedExecutorService, Clock.systemUTC()); messagesDynamoDb = new MessagesDynamoDb(DYNAMO_DB_EXTENSION.getDynamoDbClient(), @@ -116,14 +115,10 @@ class WebSocketConnectionIntegrationTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(Device.PRIMARY_ID); - - redisMessageAvailabilityManager.start(); } @AfterEach void tearDown() throws Exception { - redisMessageAvailabilityManager.stop(); - sharedExecutorService.shutdown(); //noinspection ResultOfMethodCallIgnored sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS); @@ -148,8 +143,7 @@ class WebSocketConnectionIntegrationTest { messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class) - ); + mock(ExperimentEnrollmentManager.class)); final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); @@ -177,21 +171,36 @@ class WebSocketConnectionIntegrationTest { } final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + final AtomicBoolean queueCleared = new AtomicBoolean(false); when(successResponse.getStatus()).thenReturn(200); when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())) .thenReturn(CompletableFuture.completedFuture(successResponse)); - webSocketConnection.start(); + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer( + (Answer>) invocation -> { + synchronized (queueCleared) { + queueCleared.set(true); + queueCleared.notifyAll(); + } - @SuppressWarnings("unchecked") final ArgumentCaptor> messageBodyCaptor = - ArgumentCaptor.forClass(Optional.class); + return CompletableFuture.completedFuture(successResponse); + }); - verify(webSocketClient, timeout(10_000)) - .sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + webSocketConnection.processStoredMessages(); - verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)) - .sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); + synchronized (queueCleared) { + while (!queueCleared.get()) { + queueCleared.wait(); + } + } + + @SuppressWarnings("unchecked") final ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass( + Optional.class); + + verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), + eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); + verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); final List sentMessages = new ArrayList<>(); @@ -223,8 +232,7 @@ class WebSocketConnectionIntegrationTest { messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class) - ); + mock(ExperimentEnrollmentManager.class)); final int persistedMessageCount = 207; final int cachedMessageCount = 173; @@ -256,10 +264,10 @@ class WebSocketConnectionIntegrationTest { when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn( CompletableFuture.failedFuture(new IOException("Connection closed"))); - webSocketConnection.start(); + webSocketConnection.processStoredMessages(); //noinspection unchecked - final ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); + ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); @@ -267,7 +275,7 @@ class WebSocketConnectionIntegrationTest { eq(Optional.empty())); final List sentMessages = messageBodyCaptor.getAllValues().stream() - .map(Optional::orElseThrow) + .map(Optional::get) .map(messageBytes -> { try { return Envelope.parseFrom(messageBytes); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 9b8d44963..f9c9d215a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -5,39 +5,54 @@ package org.whispersystems.textsecuregcm.websocket; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import com.google.common.net.HttpHeaders; import com.google.protobuf.ByteString; -import io.lettuce.core.RedisCommandTimeoutException; +import com.google.protobuf.InvalidProtocolBufferException; +import io.dropwizard.auth.basic.BasicCredentials; import io.lettuce.core.RedisException; +import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.Arrays; +import java.time.Instant; +import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Queue; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.InOrder; +import org.mockito.stubbing.Answer; +import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -46,43 +61,47 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; +import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; -import org.whispersystems.textsecuregcm.storage.ConflictingMessageConsumerException; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.MessageStream; -import org.whispersystems.textsecuregcm.storage.MessageStreamEntry; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import reactor.adapter.JdkFlowAdapter; +import org.whispersystems.websocket.session.WebSocketSessionContext; import reactor.core.publisher.Flux; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; -import reactor.util.retry.Retry; class WebSocketConnectionTest { + private static final String VALID_E164 = "+14152222222"; + private static final UUID VALID_UUID = UUID.randomUUID(); + + private static final int SOURCE_DEVICE_ID = 1; + + private static final String VALID_PASSWORD = "secure"; + + private AccountAuthenticator accountAuthenticator; + private AccountsManager accountsManager; private Account account; private Device device; + private UpgradeRequest upgradeRequest; private MessagesManager messagesManager; private ReceiptSender receiptSender; private Scheduler messageDeliveryScheduler; private ClientReleaseManager clientReleaseManager; - private static final int MAX_RETRIES = 3; - private static final Retry RETRY_SPEC = Retry.backoff(MAX_RETRIES, Duration.ofMillis(5)) - .maxBackoff(Duration.ofMillis(20)) - .filter(throwable -> !WebSocketConnection.isConnectionClosedException(throwable)); - - private static final int SOURCE_DEVICE_ID = 1; - @BeforeEach void setup() { + accountAuthenticator = mock(AccountAuthenticator.class); + accountsManager = mock(AccountsManager.class); account = mock(Account.class); device = mock(Device.class); + upgradeRequest = mock(UpgradeRequest.class); messagesManager = mock(MessagesManager.class); receiptSender = mock(ReceiptSender.class); messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); @@ -95,7 +114,515 @@ class WebSocketConnectionTest { messageDeliveryScheduler.dispose(); } - private WebSocketConnection buildWebSocketConnection(final WebSocketClient client) { + @Test + void testCredentials() throws Exception { + WebSocketAccountAuthenticator webSocketAuthenticator = + new WebSocketAccountAuthenticator(accountAuthenticator); + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, + new MessageMetrics(Duration.ofDays(30)), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), + mock(RedisMessageAvailabilityManager.class), mock(DisconnectionRequestManager.class), + messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), + mock(ExperimentEnrollmentManager.class)); + WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); + + when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_E164, VALID_PASSWORD)))) + .thenReturn(Optional.of(new AuthenticatedDevice(VALID_UUID, Device.PRIMARY_ID, Instant.now()))); + + Optional account = webSocketAuthenticator.authenticate(upgradeRequest); + when(sessionContext.getAuthenticated()).thenReturn(account.orElse(null)); + when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.orElse(null)); + + final WebSocketClient webSocketClient = mock(WebSocketClient.class); + when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8"); + when(sessionContext.getClient()).thenReturn(webSocketClient); + + // authenticated - valid user + connectListener.onWebSocketConnect(sessionContext); + + verify(sessionContext, times(1)).addWebsocketClosedListener( + any(WebSocketSessionContext.WebSocketEventListener.class)); + + // unauthenticated + when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); + account = webSocketAuthenticator.authenticate(upgradeRequest); + assertFalse(account.isPresent()); + + connectListener.onWebSocketConnect(sessionContext); + verify(sessionContext, times(2)).addWebsocketClosedListener( + any(WebSocketSessionContext.WebSocketEventListener.class)); + + verifyNoMoreInteractions(messagesManager); + } + + @Test + void testOpen() { + + UUID accountUuid = UUID.randomUUID(); + UUID senderOneUuid = UUID.randomUUID(); + UUID senderTwoUuid = UUID.randomUUID(); + + List outgoingMessages = List.of(createMessage(senderOneUuid, accountUuid, 1111, "first"), + createMessage(senderOneUuid, accountUuid, 2222, "second"), + createMessage(senderTwoUuid, accountUuid, 3333, "third")); + + final byte deviceId = 2; + when(device.getId()).thenReturn(deviceId); + + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + + final Device sender1device = mock(Device.class); + + List sender1devices = List.of(sender1device); + + Account sender1 = mock(Account.class); + when(sender1.getDevices()).thenReturn(sender1devices); + + when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); + when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); + + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); + + String userAgent = HttpHeaders.USER_AGENT; + + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) + .thenReturn(Flux.fromIterable(outgoingMessages)); + + final List> futures = new LinkedList<>(); + final WebSocketClient client = mock(WebSocketClient.class); + + when(client.getUserAgent()).thenReturn(userAgent); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), nullable(List.class), any())) + .thenAnswer(invocation -> { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; + }); + + WebSocketConnection connection = webSocketConnection(client); + + connection.start(); + verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), + any()); + + assertEquals(3, futures.size()); + + WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); + when(response.getStatus()).thenReturn(200); + futures.get(1).complete(response); + + futures.get(0).completeExceptionally(new IOException()); + futures.get(2).completeExceptionally(new IOException()); + + verify(messagesManager, times(1)).delete(eq(accountUuid), eq(device), + eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp())); + verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(accountUuid)), eq(deviceId), eq(new AciServiceIdentifier(senderOneUuid)), + eq(2222L)); + + connection.stop(); + verify(client).close(anyInt(), anyString()); + } + + @Test + public void testOnlineSend() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), argThat(d -> d.getId() == Device.PRIMARY_ID), anyBoolean())) + .thenReturn(Flux.empty()) + .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))) + .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second"))) + .thenReturn(Flux.empty()); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final AtomicInteger sendCounter = new AtomicInteger(0); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> { + synchronized (sendCounter) { + sendCounter.incrementAndGet(); + sendCounter.notifyAll(); + } + + return CompletableFuture.completedFuture(successResponse); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + // This is a little hacky and non-obvious, but because the first call to getMessagesForDevice returns empty list of + // messages, the call to CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded + // future, and the whenComplete method will get called immediately on THIS thread, so we don't need to synchronize + // or wait for anything. + connection.start(); + + connection.handleNewMessageAvailable(); + + synchronized (sendCounter) { + while (sendCounter.get() < 1) { + sendCounter.wait(); + } + } + + connection.handleNewMessageAvailable(); + + synchronized (sendCounter) { + while (sendCounter.get() < 2) { + sendCounter.wait(); + } + } + }); + + verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + } + + @Test + void testPendingSend() { + final UUID accountUuid = UUID.randomUUID(); + final UUID senderTwoUuid = UUID.randomUUID(); + + final Envelope firstMessage = Envelope.newBuilder() + .setServerGuid(UUID.randomUUID().toString()) + .setSourceServiceId(UUID.randomUUID().toString()) + .setDestinationServiceId(accountUuid.toString()) + .setUpdatedPni(UUID.randomUUID().toString()) + .setClientTimestamp(System.currentTimeMillis()) + .setSourceDevice(1) + .setType(Envelope.Type.CIPHERTEXT) + .build(); + + final Envelope secondMessage = Envelope.newBuilder() + .setServerGuid(UUID.randomUUID().toString()) + .setSourceServiceId(senderTwoUuid.toString()) + .setDestinationServiceId(accountUuid.toString()) + .setClientTimestamp(System.currentTimeMillis()) + .setSourceDevice(2) + .setType(Envelope.Type.CIPHERTEXT) + .build(); + + final List pendingMessages = List.of(firstMessage, secondMessage); + + final byte deviceId = 2; + when(device.getId()).thenReturn(deviceId); + + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + + final Device sender1device = mock(Device.class); + + List sender1devices = List.of(sender1device); + + Account sender1 = mock(Account.class); + when(sender1.getDevices()).thenReturn(sender1devices); + + when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); + when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); + + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); + + String userAgent = HttpHeaders.USER_AGENT; + + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) + .thenReturn(Flux.fromIterable(pendingMessages)); + + final List> futures = new LinkedList<>(); + final WebSocketClient client = mock(WebSocketClient.class); + + when(client.getUserAgent()).thenReturn(userAgent); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) + .thenAnswer((Answer>) invocationOnMock -> { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; + }); + + WebSocketConnection connection = webSocketConnection(client); + + connection.start(); + + verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()); + + assertEquals(futures.size(), 2); + + WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); + when(response.getStatus()).thenReturn(200); + futures.get(1).complete(response); + futures.get(0).completeExceptionally(new IOException()); + + verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI))), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)), + eq(secondMessage.getClientTimestamp())); + + connection.stop(); + verify(client).close(anyInt(), anyString()); + } + + @Test + void testProcessStoredMessageConcurrency() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + final AtomicBoolean threadWaiting = new AtomicBoolean(false); + final AtomicBoolean returnMessageList = new AtomicBoolean(false); + + when( + messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) + .thenAnswer(invocation -> { + synchronized (threadWaiting) { + threadWaiting.set(true); + threadWaiting.notifyAll(); + } + + synchronized (returnMessageList) { + while (!returnMessageList.get()) { + returnMessageList.wait(); + } + } + + return Flux.empty(); + }); + + final Thread[] threads = new Thread[10]; + final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1); + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + connection.processStoredMessages(); + unblockedThreadsLatch.countDown(); + }); + + threads[i].start(); + } + + unblockedThreadsLatch.await(); + + synchronized (threadWaiting) { + while (!threadWaiting.get()) { + threadWaiting.wait(); + } + } + + synchronized (returnMessageList) { + returnMessageList.set(true); + returnMessageList.notifyAll(); + } + + for (final Thread thread : threads) { + thread.join(); + } + }); + + verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), any(), eq(false)); + } + + @Test + void testProcessStoredMessagesMultiplePages() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); + + when(account.getNumber()).thenReturn("+18005551234"); + final UUID accountUuid = UUID.randomUUID(); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + final List firstPageMessages = + List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"), + createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")); + + final List secondPageMessages = + List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); + + when(messagesManager.getMessagesForDeviceReactive(accountUuid, device, false)) + .thenReturn(Flux.fromStream(Stream.concat(firstPageMessages.stream(), secondPageMessages.stream()))); + + when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> CompletableFuture.completedFuture(successResponse)); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + connection.processStoredMessages(); + queueEmptyLatch.await(); + }); + + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), + eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + + @Test + void testProcessStoredMessagesMultiplePagesBackpressure() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); + + when(account.getNumber()).thenReturn("+18005551234"); + final UUID accountUuid = UUID.randomUUID(); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + // Create two publishers, each with >2x WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY messages + final TestPublisher firstPublisher = TestPublisher.createCold(); + final List firstPublisherMessages = IntStream.range(1, + 2 * WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY + 23) + .mapToObj(i -> createMessage(UUID.randomUUID(), UUID.randomUUID(), i, "content " + i)) + .toList(); + + final TestPublisher secondPublisher = TestPublisher.createCold(); + final List secondPublisherMessages = IntStream.range(firstPublisherMessages.size(), + firstPublisherMessages.size() + 2 * WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY + 73) + .mapToObj(i -> createMessage(UUID.randomUUID(), UUID.randomUUID(), i, "content " + i)) + .toList(); + + final Flux allMessages = Flux.concat(firstPublisher, secondPublisher); + when(messagesManager.getMessagesForDeviceReactive(accountUuid, device, false)) + .thenReturn(allMessages); + + when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); + + final Queue> pendingClientAcks = new LinkedList<>(); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> { + final CompletableFuture pendingAck = new CompletableFuture<>(); + pendingClientAcks.add(pendingAck); + return pendingAck; + }); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + // start processing + connection.processStoredMessages(); + + firstPublisher.assertWasRequested(); + // emit all messages from the first publisher + firstPublisher.emit(firstPublisherMessages.toArray(new Envelope[]{})); + // nothing should be requested from the second publisher, because max concurrency is less than the number emitted, + // and none have completed + secondPublisher.assertWasNotRequested(); + // there should only be MESSAGE_SENDER_MAX_CONCURRENCY pending client acknowledgements + assertEquals(WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY, pendingClientAcks.size()); + + while (!pendingClientAcks.isEmpty()) { + pendingClientAcks.poll().complete(successResponse); + } + + secondPublisher.assertWasRequested(); + secondPublisher.emit(secondPublisherMessages.toArray(new Envelope[0])); + + while (!pendingClientAcks.isEmpty()) { + pendingClientAcks.poll().complete(successResponse); + } + + queueEmptyLatch.await(); + }); + + verify(client, times(firstPublisherMessages.size() + secondPublisherMessages.size())).sendRequest(eq("PUT"), + eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + + @Test + void testProcessStoredMessagesContainsSenderUuid() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); + + when(account.getNumber()).thenReturn("+18005551234"); + final UUID accountUuid = UUID.randomUUID(); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + final UUID senderUuid = UUID.randomUUID(); + final List messages = List.of( + createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first")); + + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) + .thenReturn(Flux.fromIterable(messages)) + .thenReturn(Flux.empty()); + + when(messagesManager.delete(eq(accountUuid), eq(device), any(UUID.class), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer( + invocation -> CompletableFuture.completedFuture(successResponse)); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + connection.processStoredMessages(); + queueEmptyLatch.await(); + }); + + verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), + argThat(argument -> { + if (argument.isEmpty()) { + return false; + } + + final byte[] body = argument.get(); + try { + final Envelope envelope = Envelope.parseFrom(body); + if (!envelope.hasSourceServiceId() || envelope.getSourceServiceId().length() == 0) { + return false; + } + return envelope.getSourceServiceId().equals(senderUuid.toString()); + } catch (InvalidProtocolBufferException e) { + return false; + } + })); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + + private WebSocketConnection webSocketConnection(final WebSocketClient client) { return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(Duration.ofDays(30)), @@ -107,337 +634,191 @@ class WebSocketConnectionTest { Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class), - RETRY_SPEC); + mock(ExperimentEnrollmentManager.class)); } @Test - void testSendMessages() { - - final UUID destinationAccountIdentifier = UUID.randomUUID(); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - final Envelope successfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Success"); - final Envelope secondSuccessfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 2, "Second success"); - - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just( - new MessageStreamEntry.Envelope(successfulMessage), - new MessageStreamEntry.QueueEmpty(), - new MessageStreamEntry.Envelope(secondSuccessfulMessage)))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) - .thenReturn(messageStream); - - when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); - + void testProcessStoredMessagesSingleEmptyCall() { final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - when(client.isOpen()).thenReturn(true); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - .thenReturn(CompletableFuture.completedFuture(successResponse)); - - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - webSocketConnection.start(); - - verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> - body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(successfulMessage)))); - - verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> - body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(secondSuccessfulMessage)))); - - verify(messageStream).acknowledgeMessage(successfulMessage); - verify(messageStream).acknowledgeMessage(secondSuccessfulMessage); - - verify(receiptSender) - .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), - deviceId, - AciServiceIdentifier.valueOf(successfulMessage.getSourceServiceId()), - successfulMessage.getClientTimestamp()); - - verify(receiptSender) - .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), - deviceId, - AciServiceIdentifier.valueOf(secondSuccessfulMessage.getSourceServiceId()), - secondSuccessfulMessage.getClientTimestamp()); - - webSocketConnection.stop(); - - verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); - verify(client).close(eq(1000), anyString()); - } - - @Test - void testSendMessagesWithError() { - - final UUID destinationAccountIdentifier = UUID.randomUUID(); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - final Envelope successfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Success"); - final Envelope failedMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 2, "Failed"); - final Envelope secondSuccessfulMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 3, "Second success"); - - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just( - new MessageStreamEntry.Envelope(successfulMessage), - new MessageStreamEntry.Envelope(failedMessage), - new MessageStreamEntry.QueueEmpty(), - new MessageStreamEntry.Envelope(secondSuccessfulMessage)))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) - .thenReturn(messageStream); - - when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); - - final WebSocketClient client = mock(WebSocketClient.class); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - when(client.isOpen()).thenReturn(true); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - .thenReturn(CompletableFuture.completedFuture(successResponse)); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), argThat(body -> - body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(failedMessage))))) - .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())); - - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - webSocketConnection.start(); - - verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> - body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(successfulMessage)))); - - verify(client, timeout(500).times(MAX_RETRIES + 1)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> - body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(failedMessage)))); - - verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> - body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(secondSuccessfulMessage)))); - - verify(messageStream).acknowledgeMessage(successfulMessage); - verify(messageStream).acknowledgeMessage(secondSuccessfulMessage); - - verify(receiptSender) - .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), - deviceId, - AciServiceIdentifier.valueOf(successfulMessage.getSourceServiceId()), - successfulMessage.getClientTimestamp()); - - verify(receiptSender, never()) - .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), - deviceId, - AciServiceIdentifier.valueOf(failedMessage.getSourceServiceId()), - failedMessage.getClientTimestamp()); - - verify(receiptSender) - .sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier), - deviceId, - AciServiceIdentifier.valueOf(secondSuccessfulMessage.getSourceServiceId()), - secondSuccessfulMessage.getClientTimestamp()); - - verify(client, timeout(500)).close(eq(1011), anyString()); - verify(client, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); - } - - @Test - void testQueueEmptySignalOrder() { - - final UUID destinationAccountIdentifier = UUID.randomUUID(); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - final Envelope initialMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Initial message"); - final Envelope afterQueueDrainMessage = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 2, "After queue drained"); - - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just( - new MessageStreamEntry.Envelope(initialMessage), - new MessageStreamEntry.QueueEmpty(), - new MessageStreamEntry.Envelope(afterQueueDrainMessage)))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) - .thenReturn(messageStream); - - when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); - - final WebSocketClient client = mock(WebSocketClient.class); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - .thenAnswer(_ -> CompletableFuture.supplyAsync(() -> successResponse, - CompletableFuture.delayedExecutor(100, TimeUnit.MILLISECONDS))); - - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - webSocketConnection.start(); - - final InOrder inOrder = inOrder(client, messageStream); - - // Sending the initial message will succeed after a delay, at which point we'll acknowledge the message. Make sure - // we wait for that process to complete before sending the "queue empty" signal - inOrder.verify(messageStream, timeout(1_000)).acknowledgeMessage(initialMessage); - inOrder.verify(client, timeout(1_000)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); - - webSocketConnection.stop(); - verify(client).close(eq(1000), anyString()); - } - - @Test - void testConflictingConsumerSignalOrder() { - - final UUID destinationAccountIdentifier = UUID.randomUUID(); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(destinationAccountIdentifier); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - final Envelope message = createMessage(UUID.randomUUID(), destinationAccountIdentifier, 1, "Initial message"); - - final MessageStream messageStream = mock(MessageStream.class); - final TestPublisher testPublisher = TestPublisher.createCold(); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(testPublisher)); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) - .thenReturn(messageStream); - - when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); - - final WebSocketClient client = mock(WebSocketClient.class); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - .thenReturn(new CompletableFuture<>()); - - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - webSocketConnection.start(); - - testPublisher.next(new MessageStreamEntry.Envelope(message)); - testPublisher.error(new ConflictingMessageConsumerException()); - - final InOrder inOrder = inOrder(client, messageStream); - - // A "conflicting consumer" should close the socket as soon as possible (i.e. even if messages are still getting - // processed) - inOrder.verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body -> - body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(message)))); - - verify(client).close(eq(4409), anyString()); - } - - @Test - void testSendMessagesEmptyQueue() { final UUID accountUuid = UUID.randomUUID(); + when(account.getNumber()).thenReturn("+18005551234"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.just(new MessageStreamEntry.QueueEmpty()))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) + .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); - final WebSocketClient client = mock(WebSocketClient.class); - when(client.isOpen()).thenReturn(true); + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + connection.processStoredMessages(); - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - - webSocketConnection.start(); - - verify(client, timeout(1_000)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } @Test - void testSendMessagesConflictingConsumer() { + public void testRequeryOnStateMismatch() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); final UUID accountUuid = UUID.randomUUID(); + when(account.getNumber()).thenReturn("+18005551234"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); - - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.error(new ConflictingMessageConsumerException()))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); - - final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(true); - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - webSocketConnection.start(); + final List firstPageMessages = + List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"), + createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")); - verify(client, timeout(1_000)).close(eq(4409), anyString()); + final List secondPageMessages = + List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); + + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) + .thenReturn(Flux.fromIterable(firstPageMessages)) + .thenReturn(Flux.fromIterable(secondPageMessages)) + .thenReturn(Flux.empty()); + + when(messagesManager.delete(eq(accountUuid), eq(device), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> { + connection.handleNewMessageAvailable(); + + return CompletableFuture.completedFuture(successResponse); + }); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + connection.processStoredMessages(); + + queueEmptyLatch.await(); + }); + + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), + eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testSendMessagesRetrievalException(final boolean clientOpen) { + @Test + void testProcessCachedMessagesOnly() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); + final UUID accountUuid = UUID.randomUUID(); + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) + .thenReturn(Flux.empty()); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + + verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false); + + connection.handleNewMessageAvailable(); + + verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, true); + } + + @Test + void testProcessDatabaseMessagesAfterPersist() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = webSocketConnection(client); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) + .thenReturn(Flux.empty()); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + connection.handleMessagesPersisted(); + + verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false); + } + + @Test + void testRetrieveMessageException() { + UUID accountUuid = UUID.randomUUID(); + when(device.getId()).thenReturn((byte) 2); + + when(account.getNumber()).thenReturn("+14152222222"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.error(new RedisException("OH NO")))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) + .thenReturn(Flux.error(new RedisException("OH NO"))); final WebSocketClient client = mock(WebSocketClient.class); - when(client.isOpen()).thenReturn(clientOpen); + when(client.isOpen()).thenReturn(true); - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); - webSocketConnection.start(); + WebSocketConnection connection = webSocketConnection(client); + connection.start(); - if (clientOpen) { - verify(client).close(eq(1011), anyString()); - } else { - verify(client, never()).close(anyInt(), any()); - } + verify(client).close(eq(1011), anyString()); + } + + @Test + void testRetrieveMessageExceptionClientDisconnected() { + UUID accountUuid = UUID.randomUUID(); + + when(device.getId()).thenReturn((byte) 2); + + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); + + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) + .thenReturn(Flux.error(new RedisException("OH NO"))); + + final WebSocketClient client = mock(WebSocketClient.class); + when(client.isOpen()).thenReturn(false); + + WebSocketConnection connection = webSocketConnection(client); + connection.start(); + + verify(client, never()).close(anyInt(), anyString()); } @Test @@ -447,31 +828,28 @@ class WebSocketConnectionTest { final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); + when(account.getNumber()).thenReturn("+14152222222"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final int totalMessages = 1000; - final TestPublisher testPublisher = TestPublisher.createCold(); - final Flux flux = Flux.from(testPublisher); + final TestPublisher testPublisher = TestPublisher.createCold(); + final Flux flux = Flux.from(testPublisher); - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(flux)); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) + .thenReturn(flux); final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(true); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + WebSocketConnection connection = webSocketConnection(client); - webSocketConnection.start(); + connection.start(); StepVerifier.setDefaultTimeout(Duration.ofSeconds(5)); @@ -480,7 +858,7 @@ class WebSocketConnectionTest { .thenRequest(totalMessages * 2) .then(() -> { for (long i = 0; i < totalMessages; i++) { - testPublisher.next(new MessageStreamEntry.Envelope(createMessage(UUID.randomUUID(), accountUuid, 1111 * i + 1, "message " + i))); + testPublisher.next(createMessage(UUID.randomUUID(), accountUuid, 1111 * i + 1, "message " + i)); } testPublisher.complete(); }) @@ -499,46 +877,40 @@ class WebSocketConnectionTest { final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); + when(account.getNumber()).thenReturn("+14152222222"); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final AtomicBoolean canceled = new AtomicBoolean(); - final Flux flux = Flux.create(s -> { + final Flux flux = Flux.create(s -> { s.onRequest(n -> { // the subscriber should request more than 1 message, but we will only send one, so that // we are sure the subscriber is waiting for more when we stop the connection assert n > 1; - s.next(new MessageStreamEntry.Envelope(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))); + s.next(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")); }); s.onCancel(() -> canceled.set(true)); }); - - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(flux)); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(accountUuid, device)).thenReturn(messageStream); - when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(device), anyBoolean())) + .thenReturn(flux); final WebSocketClient client = mock(WebSocketClient.class); when(client.isOpen()).thenReturn(true); + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + when(messagesManager.delete(any(), any(), any(), any())).thenReturn( + CompletableFuture.completedFuture(Optional.empty())); - final WebSocketConnection webSocketConnection = buildWebSocketConnection(client); + WebSocketConnection connection = webSocketConnection(client); - webSocketConnection.start(); + connection.start(); verify(client).sendRequest(any(), any(), any(), any()); // close the connection before the publisher completes - webSocketConnection.stop(); + connection.stop(); StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); @@ -552,61 +924,7 @@ class WebSocketConnectionTest { .verify(); } - @Test - void testRetryOnError() { - final UUID accountIdentifier = UUID.randomUUID(); - - final List outgoingMessages = List.of(createMessage(accountIdentifier, accountIdentifier, 1111, "first")); - - final byte deviceId = 2; - when(device.getId()).thenReturn(deviceId); - - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); - - final MessageStream messageStream = mock(MessageStream.class); - - when(messageStream.getMessages()) - .thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(outgoingMessages) - .map(MessageStreamEntry.Envelope::new))); - - when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null)); - - when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device)) - .thenReturn(messageStream); - - when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); - - final WebSocketClient client = mock(WebSocketClient.class); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())) - .thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException())) - .thenReturn(CompletableFuture.completedFuture(successResponse)); - - final WebSocketConnection connection = buildWebSocketConnection(client); - - connection.start(); - - verify(client, timeout(500).times(3)) - .sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()); - - verify(messageStream, timeout(500)).acknowledgeMessage(outgoingMessages.getFirst()); - - verify(receiptSender, timeout(500)) - .sendReceipt(new AciServiceIdentifier(accountIdentifier), deviceId, new AciServiceIdentifier(accountIdentifier), 1111L); - - connection.stop(); - verify(client).close(eq(1000), anyString()); - } - - private static Envelope createMessage(final UUID senderUuid, - final UUID destinationUuid, - final long timestamp, - final String content) { - + private Envelope createMessage(UUID senderUuid, UUID destinationUuid, long timestamp, String content) { return Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) .setType(Envelope.Type.CIPHERTEXT)