Wait for message acknowledgement before fetching new messags from Redis/DynamoDB

This commit is contained in:
Jon Chambers
2025-08-14 12:07:40 -04:00
committed by Jon Chambers
parent 194e43926a
commit 8fe87b77e4
7 changed files with 256 additions and 160 deletions

View File

@@ -206,7 +206,10 @@ class RedisDynamoDbMessagePublisherTest {
}
deleteRedisMessage(redisMessage);
messagePublisher.handleMessageAcknowledged();
deleteDynamoDbMessage(dynamoDbMessage);
messagePublisher.handleMessageAcknowledged();
insertRedisMessage(newArrivalRedisMessage);
messagePublisher.handleNewMessageAvailable();
@@ -245,7 +248,10 @@ class RedisDynamoDbMessagePublisherTest {
}
deleteRedisMessage(redisMessage);
messagePublisher.handleMessageAcknowledged();
deleteDynamoDbMessage(dynamoDbMessage);
messagePublisher.handleMessageAcknowledged();
insertDynamoDbMessage(persistedMessage);
messagePublisher.handleMessagesPersisted();
@@ -264,6 +270,41 @@ class RedisDynamoDbMessagePublisherTest {
.verifyTimeout(Duration.ofMillis(500));
}
@Test
void publishMessagesWaitForAcknowledgement() {
final MessageProtos.Envelope dynamoDbMessage = insertDynamoDbMessage(generateRandomMessage());
final MessageProtos.Envelope redisMessage = insertRedisMessage(generateRandomMessage());
final MessageProtos.Envelope persistedMessage = generateRandomMessage();
final RedisDynamoDbMessagePublisher messagePublisher =
new RedisDynamoDbMessagePublisher(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, DESTINATION_SERVICE_IDENTIFIER.uuid(), destinationDevice);
final CountDownLatch queueEmptyCountDownLatch = new CountDownLatch(1);
Thread.ofVirtual().start(() -> {
try {
queueEmptyCountDownLatch.await();
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}
insertDynamoDbMessage(persistedMessage);
messagePublisher.handleMessagesPersisted();
});
StepVerifier.create(JdkFlowAdapter.flowPublisherToFlux(messagePublisher)
.doOnNext(entry -> {
if (entry instanceof MessageStreamEntry.QueueEmpty) {
queueEmptyCountDownLatch.countDown();
}
}))
.expectNext(new MessageStreamEntry.Envelope(dynamoDbMessage))
.expectNext(new MessageStreamEntry.Envelope(redisMessage))
.expectNext(new MessageStreamEntry.QueueEmpty())
.verifyTimeout(Duration.ofMillis(500));
}
@Test
void publishMessagesConsumerConflict() {
final RedisDynamoDbMessagePublisher messagePublisher =

View File

@@ -20,7 +20,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager;
class RedisDynamoDbMessageStreamTest {
@@ -44,9 +43,9 @@ class RedisDynamoDbMessageStreamTest {
redisDynamoDbMessageStream = new RedisDynamoDbMessageStream(messagesDynamoDb,
messagesCache,
mock(RedisMessageAvailabilityManager.class),
ACCOUNT_IDENTIFIER,
device);
device,
mock(RedisDynamoDbMessagePublisher.class));
when(messagesDynamoDb.deleteMessage(any(), any(), any(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@@ -74,7 +73,6 @@ class RedisDynamoDbMessageStreamTest {
void acknowledgeMessageRedis() {
final MessageProtos.Envelope message = generateMessage();
final UUID messageGuid = UUID.fromString(message.getServerGuid());
final long serverTimestamp = message.getServerTimestamp();
when(messagesCache.remove(ACCOUNT_IDENTIFIER, DEVICE_ID, messageGuid))
.thenReturn(CompletableFuture.completedFuture(Optional.of(RemovedMessage.fromEnvelope(message))));

View File

@@ -34,6 +34,8 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -67,6 +69,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
@@ -126,8 +129,13 @@ class WebSocketConnectionIntegrationTest {
redisMessageAvailabilityManager.stop();
sharedExecutorService.shutdown();
final Mono<Void> schedulerShutdownMono = messageDeliveryScheduler.disposeGracefully();
//noinspection ResultOfMethodCallIgnored
sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS);
schedulerShutdownMono.timeout(Duration.ofSeconds(2))
.onErrorResume(TimeoutException.class, _ -> Mono.fromRunnable(() -> messageDeliveryScheduler.dispose()))
.block();
}
@ParameterizedTest
@@ -210,6 +218,101 @@ class WebSocketConnectionIntegrationTest {
});
}
@Test
void testProcessStoredMessagesMultipleSegments() {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, redisMessageAvailabilityManager, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
account,
device,
webSocketClient,
messageDeliveryScheduler,
clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class)
);
final int persistedMessageCount = 77;
final int cachedMessageCount = 104;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
assertTimeoutPreemptively(Duration.ofSeconds(15), () -> {
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
final AtomicInteger remainingMessages = new AtomicInteger(persistedMessageCount + cachedMessageCount);
final int additionalMessageCount = 67;
when(successResponse.getStatus()).thenReturn(200);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any()))
.thenAnswer(_ -> {
if (remainingMessages.addAndGet(-1) == 60) {
sharedExecutorService.submit(() -> {
for (int i = 0; i < additionalMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
});
}
return CompletableFuture.completedFuture(successResponse);
});
webSocketConnection.start();
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor =
ArgumentCaptor.forClass(Optional.class);
verify(webSocketClient, timeout(10_000))
.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
verify(webSocketClient, timeout(10_000).times(persistedMessageCount + cachedMessageCount + additionalMessageCount))
.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
final List<MessageProtos.Envelope> sentMessages = new ArrayList<>();
for (final Optional<byte[]> maybeMessageBody : messageBodyCaptor.getAllValues()) {
maybeMessageBody.ifPresent(messageBytes -> {
try {
sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes));
} catch (final InvalidProtocolBufferException e) {
fail("Could not parse sent message");
}
});
}
assertEquals(expectedMessages, sentMessages);
});
}
@Test
void testProcessStoredMessagesClientClosed() {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
@@ -254,8 +357,8 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
CompletableFuture.failedFuture(new IOException("Connection closed")));
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any()))
.thenReturn(CompletableFuture.failedFuture(new IOException("Connection closed")));
webSocketConnection.start();
@@ -293,5 +396,4 @@ class WebSocketConnectionIntegrationTest {
.setDestinationServiceId(UUID.randomUUID().toString())
.build();
}
}

View File

@@ -27,7 +27,6 @@ import io.lettuce.core.RedisException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
@@ -66,7 +65,6 @@ import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import reactor.test.publisher.TestPublisher;
import reactor.util.retry.Retry;
class WebSocketConnectionTest {
@@ -77,11 +75,6 @@ class WebSocketConnectionTest {
private Scheduler messageDeliveryScheduler;
private ClientReleaseManager clientReleaseManager;
private static final int MAX_RETRIES = 3;
private static final Retry RETRY_SPEC = Retry.backoff(MAX_RETRIES, Duration.ofMillis(5))
.maxBackoff(Duration.ofMillis(20))
.filter(throwable -> !WebSocketConnection.isConnectionClosedException(throwable));
private static final int SOURCE_DEVICE_ID = 1;
private static final AtomicInteger ON_ERROR_DROPPED_COUNTER = new AtomicInteger();
@@ -129,8 +122,7 @@ class WebSocketConnectionTest {
Schedulers.immediate(),
clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class),
mock(ExperimentEnrollmentManager.class),
RETRY_SPEC);
mock(ExperimentEnrollmentManager.class));
}
@Test
@@ -249,14 +241,14 @@ class WebSocketConnectionTest {
verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body ->
body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(successfulMessage))));
verify(client, timeout(500).times(MAX_RETRIES + 1)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body ->
verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body ->
body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(failedMessage))));
verify(client).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body ->
verify(client, never()).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), argThat(body ->
body.isPresent() && Arrays.equals(body.get(), WebSocketConnection.serializeMessage(secondSuccessfulMessage))));
verify(messageStream).acknowledgeMessage(successfulMessage);
verify(messageStream).acknowledgeMessage(secondSuccessfulMessage);
verify(messageStream, never()).acknowledgeMessage(secondSuccessfulMessage);
verify(receiptSender)
.sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier),
@@ -270,7 +262,7 @@ class WebSocketConnectionTest {
AciServiceIdentifier.valueOf(failedMessage.getSourceServiceId()),
failedMessage.getClientTimestamp());
verify(receiptSender)
verify(receiptSender, never())
.sendReceipt(new AciServiceIdentifier(destinationAccountIdentifier),
deviceId,
AciServiceIdentifier.valueOf(secondSuccessfulMessage.getSourceServiceId()),
@@ -574,56 +566,6 @@ class WebSocketConnectionTest {
.verify();
}
@Test
void testRetryOnError() {
final UUID accountIdentifier = UUID.randomUUID();
final List<Envelope> outgoingMessages = List.of(createMessage(accountIdentifier, accountIdentifier, 1111, "first"));
final byte deviceId = 2;
when(device.getId()).thenReturn(deviceId);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
final MessageStream messageStream = mock(MessageStream.class);
when(messageStream.getMessages())
.thenReturn(JdkFlowAdapter.publisherToFlowPublisher(Flux.fromIterable(outgoingMessages)
.map(MessageStreamEntry.Envelope::new)));
when(messageStream.acknowledgeMessage(any())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.getMessages(account.getIdentifier(IdentityType.ACI), device))
.thenReturn(messageStream);
when(messagesManager.mayHaveMessages(any(), any())).thenReturn(CompletableFuture.completedFuture(false));
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException()))
.thenReturn(CompletableFuture.failedFuture(new RedisCommandTimeoutException()))
.thenReturn(CompletableFuture.completedFuture(successResponse));
final WebSocketConnection connection = buildWebSocketConnection(client);
connection.start();
verify(client, timeout(500).times(3))
.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any());
verify(messageStream, timeout(500)).acknowledgeMessage(outgoingMessages.getFirst());
verify(receiptSender, timeout(500))
.sendReceipt(new AciServiceIdentifier(accountIdentifier), deviceId, new AciServiceIdentifier(accountIdentifier), 1111L);
connection.stop();
verify(client).close(eq(1000), anyString());
}
private static Envelope createMessage(final UUID senderUuid,
final UUID destinationUuid,
final long timestamp,