Persist messages in batches.

This commit is contained in:
Jon Chambers
2020-09-23 12:23:34 -04:00
committed by Jon Chambers
parent 6041a9d094
commit fc71ced660
6 changed files with 118 additions and 106 deletions

View File

@@ -5,12 +5,14 @@ import io.lettuce.core.cluster.SlotHash;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
@@ -18,14 +20,15 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -67,17 +70,19 @@ public class MessagePersisterTest extends AbstractRedisClusterTest {
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, scheduledExecutorService, PERSIST_DELAY);
doAnswer(invocation -> {
final String destination = invocation.getArgument(0, String.class);
final UUID destinationUuid = invocation.getArgument(1, UUID.class);
final MessageProtos.Envelope message = invocation.getArgument(2, MessageProtos.Envelope.class);
final UUID messageGuid = invocation.getArgument(3, UUID.class);
final long deviceId = invocation.getArgument(4, Long.class);
final String destination = invocation.getArgument(0, String.class);
final UUID destinationUuid = invocation.getArgument(1, UUID.class);
final long deviceId = invocation.getArgument(2, Long.class);
final List<MessageProtos.Envelope> messages = invocation.getArgument(3, List.class);
messagesDatabase.store(messageGuid, message, destination, deviceId);
messagesCache.remove(destinationUuid, deviceId, messageGuid);
messagesDatabase.store(messages, destination, deviceId);
for (final MessageProtos.Envelope message : messages) {
messagesCache.remove(destinationUuid, deviceId, UUID.fromString(message.getServerGuid()));
}
return null;
}).when(messagesManager).persistMessage(anyString(), any(UUID.class), any(MessageProtos.Envelope.class), any(UUID.class), anyLong());
}).when(messagesManager).persistMessages(anyString(), any(UUID.class), anyLong(), any());
}
@Override
@@ -109,7 +114,10 @@ public class MessagePersisterTest extends AbstractRedisClusterTest {
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
verify(messagesDatabase, times(messageCount)).store(any(UUID.class), any(MessageProtos.Envelope.class), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDatabase, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID));
assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
@@ -123,7 +131,7 @@ public class MessagePersisterTest extends AbstractRedisClusterTest {
messagePersister.persistNextQueues(now);
verify(messagesDatabase, never()).store(any(UUID.class), any(MessageProtos.Envelope.class), anyString(), anyLong());
verify(messagesDatabase, never()).store(any(), anyString(), anyLong());
}
@Test
@@ -151,7 +159,10 @@ public class MessagePersisterTest extends AbstractRedisClusterTest {
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
verify(messagesDatabase, times(queueCount * messagesPerQueue)).store(any(UUID.class), any(MessageProtos.Envelope.class), anyString(), anyLong());
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDatabase, atLeastOnce()).store(messagesCaptor.capture(), anyString(), anyLong());
assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@SuppressWarnings("SameParameterValue")

View File

@@ -44,9 +44,8 @@ public class MessagesTest {
@Test
public void testStore() throws SQLException {
Envelope envelope = generateEnvelope();
UUID guid = UUID.randomUUID();
messages.store(guid, envelope, "+14151112222", 1);
messages.store(List.of(envelope), "+14151112222", 1);
PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM messages WHERE destination = ?");
statement.setString(1, "+14151112222");
@@ -54,7 +53,7 @@ public class MessagesTest {
ResultSet resultSet = statement.executeQuery();
assertThat(resultSet.next()).isTrue();
assertThat(resultSet.getString("guid")).isEqualTo(guid.toString());
assertThat(resultSet.getString("guid")).isEqualTo(envelope.getServerGuid());
assertThat(resultSet.getInt("type")).isEqualTo(envelope.getType().getNumber());
assertThat(resultSet.getString("relay")).isNullOrEmpty();
assertThat(resultSet.getLong("timestamp")).isEqualTo(envelope.getTimestamp());
@@ -71,36 +70,29 @@ public class MessagesTest {
@Test
public void testLoad() {
List<MessageToStore> inserted = new ArrayList<>(50);
List<Envelope> inserted = insertRandom("+14151112222", 1);
for (int i=0;i<50;i++) {
MessageToStore message = generateMessageToStore();
inserted.add(message);
messages.store(message.guid, message.envelope, "+14151112222", 1);
}
inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp()));
inserted.sort(Comparator.comparingLong(Envelope::getTimestamp));
List<OutgoingMessageEntity> retrieved = messages.load("+14151112222", 1);
assertThat(retrieved.size()).isEqualTo(inserted.size());
for (int i=0;i<retrieved.size();i++) {
verifyExpected(retrieved.get(i), inserted.get(i).envelope, inserted.get(i).guid);
verifyExpected(retrieved.get(i), inserted.get(i), UUID.fromString(inserted.get(i).getServerGuid()));
}
}
@Test
public void removeBySourceDestinationTimestamp() {
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
List<MessageToStore> unrelated = insertRandom("+14151114444", 3);
MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1));
Optional<OutgoingMessageEntity> removed = messages.remove("+14151112222", 1, toRemove.envelope.getSource(), toRemove.envelope.getTimestamp());
List<Envelope> inserted = insertRandom("+14151112222", 1);
List<Envelope> unrelated = insertRandom("+14151114444", 3);
Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1));
Optional<OutgoingMessageEntity> removed = messages.remove("+14151112222", 1, toRemove.getSource(), toRemove.getTimestamp());
assertThat(removed.isPresent()).isTrue();
verifyExpected(removed.get(), toRemove.envelope, toRemove.guid);
verifyExpected(removed.get(), toRemove, UUID.fromString(toRemove.getServerGuid()));
verifyInTact(inserted, "+14151112222", 1);
verifyInTact(unrelated, "+14151114444", 3);
@@ -108,13 +100,13 @@ public class MessagesTest {
@Test
public void removeByDestinationGuid() {
List<MessageToStore> unrelated = insertRandom("+14151113333", 2);
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
MessageToStore toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1));
Optional<OutgoingMessageEntity> removed = messages.remove("+14151112222", toRemove.guid);
List<Envelope> unrelated = insertRandom("+14151113333", 2);
List<Envelope> inserted = insertRandom("+14151112222", 1);
Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1));
Optional<OutgoingMessageEntity> removed = messages.remove("+14151112222", UUID.fromString(toRemove.getServerGuid()));
assertThat(removed.isPresent()).isTrue();
verifyExpected(removed.get(), toRemove.envelope, toRemove.guid);
verifyExpected(removed.get(), toRemove, UUID.fromString(toRemove.getServerGuid()));
verifyInTact(inserted, "+14151112222", 1);
verifyInTact(unrelated, "+14151113333", 2);
@@ -122,10 +114,10 @@ public class MessagesTest {
@Test
public void removeByDestinationRowId() {
List<MessageToStore> unrelatedInserted = insertRandom("+14151111111", 1);
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
List<Envelope> unrelatedInserted = insertRandom("+14151111111", 1);
List<Envelope> inserted = insertRandom("+14151112222", 1);
inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp()));
inserted.sort(Comparator.comparingLong(Envelope::getTimestamp));
List<OutgoingMessageEntity> retrieved = messages.load("+14151112222", 1);
@@ -141,9 +133,8 @@ public class MessagesTest {
@Test
public void testLoadEmpty() {
List<MessageToStore> inserted = insertRandom("+14151112222", 1);
List<OutgoingMessageEntity> loaded = messages.load("+14159999999", 1);
assertThat(loaded.isEmpty()).isTrue();
insertRandom("+14151112222", 1);
assertThat(messages.load("+14159999999", 1).isEmpty()).isTrue();
}
@Test
@@ -151,7 +142,7 @@ public class MessagesTest {
insertRandom("+14151112222", 1);
insertRandom("+14151112222", 2);
List<MessageToStore> unrelated = insertRandom("+14151111111", 1);
List<Envelope> unrelated = insertRandom("+14151111111", 1);
messages.clear("+14151112222");
@@ -163,9 +154,9 @@ public class MessagesTest {
@Test
public void testClearDestinationDevice() {
insertRandom("+14151112222", 1);
List<MessageToStore> inserted = insertRandom("+14151112222", 2);
List<Envelope> inserted = insertRandom("+14151112222", 2);
List<MessageToStore> unrelated = insertRandom("+14151111111", 1);
List<Envelope> unrelated = insertRandom("+14151111111", 1);
messages.clear("+14151112222", 1);
@@ -177,33 +168,32 @@ public class MessagesTest {
@Test
public void testVacuum() {
List<MessageToStore> inserted = insertRandom("+14151112222", 2);
List<Envelope> inserted = insertRandom("+14151112222", 2);
messages.vacuum();
verifyInTact(inserted, "+14151112222", 2);
}
private List<MessageToStore> insertRandom(String destination, int destinationDevice) {
List<MessageToStore> inserted = new ArrayList<>(50);
private List<Envelope> insertRandom(String destination, int destinationDevice) {
List<Envelope> inserted = new ArrayList<>(50);
for (int i=0;i<50;i++) {
MessageToStore message = generateMessageToStore();
inserted.add(message);
messages.store(message.guid, message.envelope, destination, destinationDevice);
inserted.add(generateEnvelope());
}
messages.store(inserted, destination, destinationDevice);
return inserted;
}
private void verifyInTact(List<MessageToStore> inserted, String destination, int destinationDevice) {
inserted.sort(Comparator.comparingLong(o -> o.envelope.getTimestamp()));
private void verifyInTact(List<Envelope> inserted, String destination, int destinationDevice) {
inserted.sort(Comparator.comparingLong(Envelope::getTimestamp));
List<OutgoingMessageEntity> retrieved = messages.load(destination, destinationDevice);
assertThat(retrieved.size()).isEqualTo(inserted.size());
for (int i=0;i<retrieved.size();i++) {
verifyExpected(retrieved.get(i), inserted.get(i).envelope, inserted.get(i).guid);
verifyExpected(retrieved.get(i), inserted.get(i), UUID.fromString(inserted.get(i).getServerGuid()));
}
}
@@ -220,10 +210,6 @@ public class MessagesTest {
assertThat(retrieved.getSourceDevice()).isEqualTo(inserted.getSourceDevice());
}
private MessageToStore generateMessageToStore() {
return new MessageToStore(UUID.randomUUID(), generateEnvelope());
}
private Envelope generateEnvelope() {
Random random = new Random();
byte[] content = new byte[256];
@@ -233,6 +219,7 @@ public class MessagesTest {
Arrays.fill(legacy, (byte)random.nextInt(255));
return Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setSourceDevice(random.nextInt(10000))
.setSource("testSource" + random.nextInt())
.setTimestamp(serialTimestamp++)
@@ -243,14 +230,4 @@ public class MessagesTest {
.setServerGuid(UUID.randomUUID().toString())
.build();
}
private static class MessageToStore {
private final UUID guid;
private final Envelope envelope;
private MessageToStore(UUID guid, Envelope envelope) {
this.guid = guid;
this.envelope = envelope;
}
}
}

View File

@@ -63,6 +63,8 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
private WebSocketClient webSocketClient;
private WebSocketConnection webSocketConnection;
private long serialTimestamp = System.currentTimeMillis();
@Before
public void setupAccountsDao() {
}
@@ -108,12 +110,17 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
messages.store(messageGuid, envelope, account.getNumber(), device.getId());
expectedMessages.add(envelope.toBuilder().clearServerGuid().build());
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope.toBuilder().clearServerGuid().build());
}
messages.store(persistedMessages, account.getNumber(), device.getId());
}
for (int i = 0; i < cachedMessageCount; i++) {
@@ -172,9 +179,14 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
for (int i = 0; i < persistedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
messages.store(messageGuid, generateRandomMessage(messageGuid), account.getNumber(), device.getId());
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) {
persistedMessages.add(generateRandomMessage(UUID.randomUUID()));
}
messages.store(persistedMessages, account.getNumber(), device.getId());
}
for (int i = 0; i < cachedMessageCount; i++) {
@@ -191,9 +203,11 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
}
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) {
final long timestamp = serialTimestamp++;
return MessageProtos.Envelope.newBuilder()
.setTimestamp(System.currentTimeMillis())
.setServerTimestamp(System.currentTimeMillis())
.setTimestamp(timestamp)
.setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())