diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 23f7d8a93..1fe5c6b05 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -300,6 +300,13 @@ public class MessagesCache { } + public CompletableFuture hasMessagesAsync(final UUID destinationUuid, final byte destinationDevice) { + return redisCluster.withBinaryCluster(connection -> + connection.async().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) + .thenApply(cardinality -> cardinality > 0)) + .toCompletableFuture(); + } + public Publisher get(final UUID destinationUuid, final byte destinationDevice) { final long earliestAllowableEphemeralTimestamp = diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index cc2296e66..3f049a6fb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -32,6 +32,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.push.RedisMessageAvailabilityManager; import org.whispersystems.textsecuregcm.util.Pair; import reactor.core.observability.micrometer.Micrometer; @@ -51,6 +52,9 @@ public class MessagesManager { private static final Counter PERSIST_MESSAGE_BYTES_COUNTER = Metrics.counter( name(MessagesManager.class, "persistMessageBytes")); + private static final String MAY_HAVE_MESSAGES_COUNTER_NAME = + MetricsUtil.name(MessagesManager.class, "mayHaveMessages"); + private final MessagesDynamoDb messagesDynamoDb; private final MessagesCache messagesCache; private final RedisMessageAvailabilityManager redisMessageAvailabilityManager; @@ -178,6 +182,28 @@ public class MessagesManager { return messagesDynamoDb.mayHaveMessages(destinationUuid, destinationDevice); } + public CompletableFuture mayHaveMessages(final UUID destinationUuid, final Device destinationDevice) { + return messagesCache.hasMessagesAsync(destinationUuid, destinationDevice.getId()) + .thenCombine(messagesDynamoDb.mayHaveMessages(destinationUuid, destinationDevice), + (mayHaveCachedMessages, mayHavePersistedMessages) -> { + final String outcome; + + if (mayHaveCachedMessages && mayHavePersistedMessages) { + outcome = "both"; + } else if (mayHaveCachedMessages) { + outcome = "cached"; + } else if (mayHavePersistedMessages) { + outcome = "persisted"; + } else { + outcome = "none"; + } + + Metrics.counter(MAY_HAVE_MESSAGES_COUNTER_NAME, "outcome", outcome).increment(); + + return mayHaveCachedMessages || mayHavePersistedMessages; + }); + } + public CompletableFuture mayHaveUrgentPersistedMessages(final UUID destinationUuid, final Device destinationDevice) { return messagesDynamoDb.mayHaveUrgentMessages(destinationUuid, destinationDevice); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 24bf125cf..e6a6a0f7e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -194,6 +194,17 @@ class MessagesCacheTest { assertEquals(messagesToPreserve, get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } + @Test + void testHasMessagesAsync() { + assertFalse(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join()); + + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); + + assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join()); + } + @Test void getOldestTimestamp() { final int messageCount = 100; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index c198d551c..fd01c3a9a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -12,6 +12,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -185,6 +186,31 @@ class MessagesManagerTest { any()); } + @ParameterizedTest + @CsvSource({ + "false, false, false", + "false, true, true", + "true, false, true", + "true, true, true" + }) + void mayHaveMessages(final boolean hasCachedMessages, final boolean hasPersistedMessages, final boolean expectMayHaveMessages) { + final UUID accountIdentifier = UUID.randomUUID(); + final Device device = mock(Device.class); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + + when(messagesCache.hasMessagesAsync(accountIdentifier, Device.PRIMARY_ID)) + .thenReturn(CompletableFuture.completedFuture(hasCachedMessages)); + + when(messagesDynamoDb.mayHaveMessages(accountIdentifier, device)) + .thenReturn(CompletableFuture.completedFuture(hasPersistedMessages)); + + if (hasCachedMessages) { + verifyNoInteractions(messagesDynamoDb); + } + + assertEquals(expectMayHaveMessages, messagesManager.mayHaveMessages(accountIdentifier, device).join()); + } + @ParameterizedTest @CsvSource({ ",,",