diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 6cf3d56eb..8e5bbec4a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -102,8 +102,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn @VisibleForTesting static final int MESSAGE_SENDER_MAX_CONCURRENCY = 256; - static final Duration DEFAULT_SEND_FUTURES_TIMEOUT = Duration.ofMinutes(5); - private static final Duration CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY = Duration.ofMinutes(1); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); @@ -120,8 +118,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn private final Device authenticatedDevice; private final WebSocketClient client; - private final Duration sendFuturesTimeout; - private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); private final AtomicReference storedMessageState = new AtomicReference<>( StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); @@ -140,47 +136,18 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn PERSISTED_NEW_MESSAGES_AVAILABLE } - public WebSocketConnection(ReceiptSender receiptSender, - MessagesManager messagesManager, - MessageMetrics messageMetrics, - PushNotificationManager pushNotificationManager, - PushNotificationScheduler pushNotificationScheduler, - Account authenticatedAccount, - Device authenticatedDevice, - WebSocketClient client, - Scheduler messageDeliveryScheduler, - ClientReleaseManager clientReleaseManager, - MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, - ExperimentEnrollmentManager experimentEnrollmentManager) { - - this(receiptSender, - messagesManager, - messageMetrics, - pushNotificationManager, - pushNotificationScheduler, - authenticatedAccount, - authenticatedDevice, - client, - DEFAULT_SEND_FUTURES_TIMEOUT, - messageDeliveryScheduler, - clientReleaseManager, - messageDeliveryLoopMonitor, experimentEnrollmentManager); - } - - @VisibleForTesting - WebSocketConnection(ReceiptSender receiptSender, - MessagesManager messagesManager, - MessageMetrics messageMetrics, - PushNotificationManager pushNotificationManager, - PushNotificationScheduler pushNotificationScheduler, - Account authenticatedAccount, - Device authenticatedDevice, - WebSocketClient client, - Duration sendFuturesTimeout, - Scheduler messageDeliveryScheduler, - ClientReleaseManager clientReleaseManager, - MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, - ExperimentEnrollmentManager experimentEnrollmentManager) { + public WebSocketConnection(final ReceiptSender receiptSender, + final MessagesManager messagesManager, + final MessageMetrics messageMetrics, + final PushNotificationManager pushNotificationManager, + final PushNotificationScheduler pushNotificationScheduler, + final Account authenticatedAccount, + final Device authenticatedDevice, + final WebSocketClient client, + final Scheduler messageDeliveryScheduler, + final ClientReleaseManager clientReleaseManager, + final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, + final ExperimentEnrollmentManager experimentEnrollmentManager) { this.receiptSender = receiptSender; this.messagesManager = messagesManager; @@ -190,7 +157,6 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn this.authenticatedAccount = authenticatedAccount; this.authenticatedDevice = authenticatedDevice; this.client = client; - this.sendFuturesTimeout = sendFuturesTimeout; this.messageDeliveryScheduler = messageDeliveryScheduler; this.clientReleaseManager = clientReleaseManager; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; @@ -384,10 +350,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Disconn "websocket"); } }) - .flatMapSequential(envelope -> - Mono.fromFuture(() -> sendMessage(envelope)).timeout(sendFuturesTimeout) - // Note that this will retry both for "send to client" timeouts and failures to delete messages on - // acknowledgement + .flatMapSequential(envelope -> Mono.fromFuture(() -> sendMessage(envelope)) .retryWhen(Retry.backoff(4, Duration.ofSeconds(1)).filter(throwable -> !isConnectionClosedException(throwable))), MESSAGE_SENDER_MAX_CONCURRENCY) .doOnError(this::measureSendMessageErrors) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index a3f4e74fa..abca22849 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -37,6 +37,7 @@ import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -68,6 +69,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +@Timeout(value = 30, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class WebSocketConnectionIntegrationTest { @RegisterExtension @@ -283,105 +285,6 @@ class WebSocketConnectionIntegrationTest { }); } - @Test - void testProcessStoredMessagesSendFutureTimeout() { - final WebSocketConnection webSocketConnection = new WebSocketConnection( - mock(ReceiptSender.class), - new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()), - new MessageMetrics(), - mock(PushNotificationManager.class), - mock(PushNotificationScheduler.class), - account, - device, - webSocketClient, - Duration.ofSeconds(1), // use a short timeout, so that this test completes quickly - messageDeliveryScheduler, - clientReleaseManager, - mock(MessageDeliveryLoopMonitor.class), - mock(ExperimentEnrollmentManager.class)); - - final int persistedMessageCount = 207; - final int cachedMessageCount = 173; - - final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); - - assertTimeoutPreemptively(Duration.ofSeconds(15), () -> { - - { - final List persistedMessages = new ArrayList<>(persistedMessageCount); - - for (int i = 0; i < persistedMessageCount; i++) { - final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); - persistedMessages.add(envelope); - expectedMessages.add(envelope); - } - - messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device); - } - - for (int i = 0; i < cachedMessageCount; i++) { - final UUID messageGuid = UUID.randomUUID(); - final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join(); - - expectedMessages.add(envelope); - } - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final CompletableFuture neverCompleting = new CompletableFuture<>(); - - // for the first message, return a future that never completes - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())) - .thenReturn(neverCompleting) - .thenReturn(CompletableFuture.completedFuture(successResponse)); - - when(webSocketClient.isOpen()).thenReturn(true); - - final AtomicBoolean queueCleared = new AtomicBoolean(false); - - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer( - (Answer>) invocation -> { - synchronized (queueCleared) { - queueCleared.set(true); - queueCleared.notifyAll(); - } - - return CompletableFuture.completedFuture(successResponse); - }); - - webSocketConnection.processStoredMessages(); - - synchronized (queueCleared) { - while (!queueCleared.get()) { - queueCleared.wait(); - } - } - - //noinspection unchecked - ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); - - // We expect all of the messages from both pools to be sent, plus one for the future that times out - verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount + 1)) - .sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); - - verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); - - final List sentMessages = messageBodyCaptor.getAllValues().stream() - .map(Optional::get) - .map(messageBytes -> { - try { - return Envelope.parseFrom(messageBytes); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(e); - } - }).toList(); - - assertTrue(expectedMessages.containsAll(sentMessages)); - }); - } - private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) { final long timestamp = serialTimestamp++; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 2672ceaf1..06e9408dc 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -18,7 +18,6 @@ import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -49,7 +48,6 @@ import java.util.stream.Stream; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; @@ -633,7 +631,6 @@ class WebSocketConnectionTest { account, device, client, - Duration.ofSeconds(1), Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), @@ -927,40 +924,6 @@ class WebSocketConnectionTest { .verify(); } - @Test - @Disabled("Slow test") - public void testClientTimeout() { - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = webSocketConnection(client); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - when(client.isOpen()).thenReturn(true); - - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), argThat(d -> d.getId() == Device.PRIMARY_ID), anyBoolean())) - .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))) - .thenReturn(Flux.empty()); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) - // This future will never complete and should time out - .thenAnswer(_ -> new CompletableFuture<>()); - - connection.start(); - - connection.handleNewMessageAvailable(); - - verify(client, timeout(30_000).times(5)) - .sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()); - - verify(client, timeout(30_000)).close(eq(1011), any()); - } - private Envelope createMessage(UUID senderUuid, UUID destinationUuid, long timestamp, String content) { return Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString())