Make Envelope the main unit of currency when working with stored messages

This commit is contained in:
Jon Chambers
2022-07-27 15:43:39 -04:00
committed by Jon Chambers
parent 3e0919106d
commit 3636626e09
9 changed files with 245 additions and 278 deletions

View File

@@ -27,6 +27,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
@@ -43,6 +44,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -57,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccou
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
@@ -77,6 +80,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
@@ -371,18 +375,22 @@ class MessageControllerTest {
final long timestampTwo = 313388;
final UUID messageGuidOne = UUID.randomUUID();
final UUID messageGuidTwo = UUID.randomUUID();
final UUID sourceUuid = UUID.randomUUID();
final UUID updatedPniOne = UUID.randomUUID();
List<OutgoingMessageEntity> messages = new LinkedList<>() {{
add(new OutgoingMessageEntity(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0));
add(new OutgoingMessageEntity(null, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0));
}};
List<Envelope> messages = List.of(
generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0),
generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0)
);
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages.stream()
.map(OutgoingMessageEntity::fromEnvelope)
.toList(), false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean()))
.thenReturn(new Pair<>(messages, false));
OutgoingMessageEntityList response =
resources.getJerseyTest().target("/v1/messages/")
@@ -397,7 +405,7 @@ class MessageControllerTest {
assertEquals(response.messages().get(1).timestamp(), timestampTwo);
assertEquals(response.messages().get(0).guid(), messageGuidOne);
assertNull(response.messages().get(1).guid());
assertEquals(response.messages().get(1).guid(), messageGuidTwo);
assertEquals(response.messages().get(0).sourceUuid(), sourceUuid);
assertEquals(response.messages().get(1).sourceUuid(), sourceUuid);
@@ -411,14 +419,13 @@ class MessageControllerTest {
final long timestampOne = 313377;
final long timestampTwo = 313388;
List<OutgoingMessageEntity> messages = new LinkedList<>() {{
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0));
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0));
}};
final List<Envelope> messages = List.of(
generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0),
generateEnvelope(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)
);
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean()))
.thenReturn(new Pair<>(messages, false));
Response response =
resources.getJerseyTest().target("/v1/messages/")
@@ -437,12 +444,12 @@ class MessageControllerTest {
UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(new OutgoingMessageEntity(
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(generateEnvelope(
uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)));
UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(new OutgoingMessageEntity(
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0)));
@@ -624,4 +631,34 @@ class MessageControllerTest {
Arguments.of("fixtures/current_message_single_device_server_receipt_type.json", false)
);
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, String source, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type))
.setTimestamp(timestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString())
.setServerGuid(guid.toString());
if (StringUtils.isNotEmpty(source)) {
builder.setSource(source)
.setSourceDevice(sourceDevice);
if (sourceUuid != null) {
builder.setSourceUuid(sourceUuid.toString());
}
}
if (content != null) {
builder.setContent(ByteString.copyFrom(content));
}
if (updatedPni != null) {
builder.setUpdatedPni(updatedPni.toString());
}
return builder.build();
}
}

View File

@@ -34,7 +34,6 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheTest {
@@ -103,11 +102,10 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
final Optional<MessageProtos.Envelope> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
DESTINATION_DEVICE_ID, messageGuid);
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(MessagesCache.constructEntityFromEnvelope(message), maybeRemovedMessage.get());
assertEquals(Optional.of(message), maybeRemovedMessage);
}
@ParameterizedTest
@@ -135,14 +133,11 @@ class MessagesCacheTest {
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
}
final List<OutgoingMessageEntity> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
final List<MessageProtos.Envelope> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid()))
.collect(Collectors.toList()));
assertEquals(messagesToRemove.stream().map(MessagesCache::constructEntityFromEnvelope)
.collect(Collectors.toList()),
removedMessages);
assertEquals(messagesToRemove, removedMessages);
assertEquals(messagesToPreserve,
messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@@ -163,14 +158,14 @@ class MessagesCacheTest {
void testGetMessages(final boolean sealedSender) {
final int messageCount = 100;
final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount);
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
expectedMessages.add(MessagesCache.constructEntityFromEnvelope(message));
expectedMessages.add(message);
}
assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));

View File

@@ -83,15 +83,15 @@ class MessagesDynamoDbTest {
final int destinationDeviceId = random.nextInt(255) + 1;
messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId);
final List<OutgoingMessageEntity> messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
final List<MessageProtos.Envelope> messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
MessagesDynamoDb.RESULT_SET_CHUNK_SIZE);
assertThat(messagesStored).isNotNull().hasSize(3);
final MessageProtos.Envelope firstMessage =
MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3;
final MessageProtos.Envelope secondMessage = firstMessage == MESSAGE1 ? MESSAGE3 : MESSAGE1;
assertThat(messagesStored).element(0).satisfies(verify(firstMessage));
assertThat(messagesStored).element(1).satisfies(verify(secondMessage));
assertThat(messagesStored).element(2).satisfies(verify(MESSAGE2));
assertThat(messagesStored).element(0).isEqualTo(firstMessage);
assertThat(messagesStored).element(1).isEqualTo(secondMessage);
assertThat(messagesStored).element(2).isEqualTo(MESSAGE2);
}
@Test
@@ -103,18 +103,18 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@Test
@@ -126,19 +126,19 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@Test
@@ -150,19 +150,19 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid,
UUID.fromString(MESSAGE2.getServerGuid()));
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
@@ -176,50 +176,20 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessage(secondDestinationUuid, 1,
UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp());
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
.element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
.element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
private static void verify(OutgoingMessageEntity retrieved, MessageProtos.Envelope inserted) {
assertThat(retrieved.timestamp()).isEqualTo(inserted.getTimestamp());
assertThat(retrieved.source()).isEqualTo(inserted.hasSource() ? inserted.getSource() : null);
assertThat(retrieved.sourceUuid()).isEqualTo(inserted.hasSourceUuid() ? UUID.fromString(inserted.getSourceUuid()) : null);
assertThat(retrieved.sourceDevice()).isEqualTo(inserted.getSourceDevice());
assertThat(retrieved.type()).isEqualTo(inserted.getType().getNumber());
assertThat(retrieved.content()).isEqualTo(inserted.hasContent() ? inserted.getContent().toByteArray() : null);
assertThat(retrieved.serverTimestamp()).isEqualTo(inserted.getServerTimestamp());
assertThat(retrieved.guid()).isEqualTo(UUID.fromString(inserted.getServerGuid()));
assertThat(retrieved.destinationUuid()).isEqualTo(UUID.fromString(inserted.getDestinationUuid()));
}
private static VerifyMessage verify(MessageProtos.Envelope expected) {
return new VerifyMessage(expected);
}
private static final class VerifyMessage implements Consumer<OutgoingMessageEntity> {
private final MessageProtos.Envelope expected;
public VerifyMessage(MessageProtos.Envelope expected) {
this.expected = expected;
}
@Override
public void accept(OutgoingMessageEntity outgoingMessageEntity) {
verify(outgoingMessageEntity, expected);
}
}
}

View File

@@ -23,15 +23,17 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.basic.BasicCredentials;
import io.lettuce.core.RedisException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
@@ -49,7 +51,6 @@ import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
@@ -111,14 +112,10 @@ class WebSocketConnectionTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<>() {{
put("login", new LinkedList<>() {{
add(VALID_USER);
}});
put("password", new LinkedList<>() {{
add(VALID_PASSWORD);
}});
}});
when(upgradeRequest.getParameterMap()).thenReturn(Map.of(
"login", List.of(VALID_USER),
"password", List.of(VALID_PASSWORD)));
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
@@ -127,14 +124,10 @@ class WebSocketConnectionTest {
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{
put("login", new LinkedList<String>() {{
add(INVALID_USER);
}});
put("password", new LinkedList<String>() {{
add(INVALID_PASSWORD);
}});
}});
when(upgradeRequest.getParameterMap()).thenReturn(Map.of(
"login", List.of(INVALID_USER),
"password", List.of(INVALID_PASSWORD)
));
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent());
@@ -149,13 +142,9 @@ class WebSocketConnectionTest {
UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first"));
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, "second"));
add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
List<Envelope> outgoingMessages = List.of(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"),
createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, "second"),
createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third"));
when(device.getId()).thenReturn(2L);
@@ -175,7 +164,7 @@ class WebSocketConnectionTest {
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList);
.thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
@@ -207,7 +196,7 @@ class WebSocketConnectionTest {
futures.get(0).completeExceptionally(new IOException());
futures.get(2).completeExceptionally(new IOException());
verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).guid()), eq(outgoingMessages.get(1).serverTimestamp()));
verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp()));
verify(receiptSender, times(1)).sendReceipt(eq(auth), eq(senderOneUuid), eq(2222L));
connection.stop();
@@ -229,9 +218,9 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first")), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second")), false));
.thenReturn(new Pair<>(Collections.emptyList(), false))
.thenReturn(new Pair<>(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first")), false))
.thenReturn(new Pair<>(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second")), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -282,36 +271,27 @@ class WebSocketConnectionTest {
final UUID senderTwoUuid = UUID.randomUUID();
final Envelope firstMessage = Envelope.newBuilder()
.setSource("sender1")
.setSourceUuid(UUID.randomUUID().toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setUpdatedPni(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT)
.build();
.setServerGuid(UUID.randomUUID().toString())
.setSource("sender1")
.setSourceUuid(UUID.randomUUID().toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setUpdatedPni(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(1)
.setType(Envelope.Type.CIPHERTEXT)
.build();
final Envelope secondMessage = Envelope.newBuilder()
.setSource("sender2")
.setSourceUuid(senderTwoUuid.toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(2)
.setType(Envelope.Type.CIPHERTEXT)
.build();
.setServerGuid(UUID.randomUUID().toString())
.setSource("sender2")
.setSourceUuid(senderTwoUuid.toString())
.setDestinationUuid(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(2)
.setType(Envelope.Type.CIPHERTEXT)
.build();
List<OutgoingMessageEntity> pendingMessages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(UUID.randomUUID(), firstMessage.getType().getNumber(),
firstMessage.getTimestamp(), firstMessage.getSource(), UUID.fromString(firstMessage.getSourceUuid()),
firstMessage.getSourceDevice(), UUID.fromString(firstMessage.getDestinationUuid()), UUID.fromString(firstMessage.getUpdatedPni()),
firstMessage.getContent().toByteArray(), 0));
add(new OutgoingMessageEntity(UUID.randomUUID(), secondMessage.getType().getNumber(),
secondMessage.getTimestamp(), secondMessage.getSource(), UUID.fromString(secondMessage.getSourceUuid()),
secondMessage.getSourceDevice(), UUID.fromString(secondMessage.getDestinationUuid()), null,
secondMessage.getContent().toByteArray(), 0));
}};
OutgoingMessageEntityList pendingMessagesList = new OutgoingMessageEntityList(pendingMessages, false);
final List<Envelope> pendingMessages = List.of(firstMessage, secondMessage);
when(device.getId()).thenReturn(2L);
@@ -331,20 +311,17 @@ class WebSocketConnectionTest {
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(pendingMessagesList);
.thenReturn(new Pair<>(pendingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()))
.thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>) invocationOnMock -> {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
@@ -352,8 +329,7 @@ class WebSocketConnectionTest {
connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any());
assertEquals(futures.size(), 2);
@@ -446,19 +422,16 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first"),
createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second"));
final List<Envelope> firstPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first"),
createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, false, "third"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
final List<Envelope> secondPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false))
.thenReturn(firstPage)
.thenReturn(secondPage);
.thenReturn(new Pair<>(firstPageMessages, true))
.thenReturn(new Pair<>(secondPageMessages, false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -493,11 +466,11 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID();
final List<OutgoingMessageEntity> messages = List.of(
createMessage("senderE164", senderUuid, UUID.randomUUID(), 1111L, false, "message the first"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false);
final List<Envelope> messages = List.of(
createMessage("senderE164", senderUuid, UUID.randomUUID(), 1111L, "message the first"));
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false))
.thenReturn(new Pair<>(messages, false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -549,7 +522,7 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -577,20 +550,17 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first"),
createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second"));
final List<Envelope> firstPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first"),
createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, false, "third"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
final List<Envelope> secondPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(firstPage)
.thenReturn(secondPage)
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new Pair<>(firstPageMessages, false))
.thenReturn(new Pair<>(secondPageMessages, false))
.thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -629,7 +599,7 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -662,7 +632,7 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -685,13 +655,11 @@ class WebSocketConnectionTest {
UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first"));
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)));
add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
List<Envelope> outgoingMessages = List.of(
createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"),
createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222,
RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)),
createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third"));
when(device.getId()).thenReturn(2L);
@@ -711,7 +679,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Desktop/1.2.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList);
.thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
@@ -758,13 +726,10 @@ class WebSocketConnectionTest {
UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first"));
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)));
add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
List<Envelope> outgoingMessages = List.of(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"),
createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222,
RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)),
createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third"));
when(device.getId()).thenReturn(2L);
@@ -784,7 +749,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Android/4.68.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList);
.thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
@@ -883,9 +848,18 @@ class WebSocketConnectionTest {
verify(client, never()).close(anyInt(), anyString());
}
private OutgoingMessageEntity createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(UUID.randomUUID(), receipt ? Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sender, senderUuid, 1, destinationUuid, null, content.getBytes(), 0);
private Envelope createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, String content) {
return Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
.setType(Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp)
.setServerTimestamp(0)
.setSource(sender)
.setSourceUuid(senderUuid.toString())
.setSourceDevice(1)
.setDestinationUuid(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build();
}
}