Internalize destination device list/registration ID checks in MessageSender

This commit is contained in:
Jon Chambers
2025-04-07 09:15:39 -04:00
committed by GitHub
parent 1d0e2d29a7
commit c6689ca07a
21 changed files with 675 additions and 755 deletions

View File

@@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
@@ -76,12 +77,12 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
@@ -120,6 +121,7 @@ import org.whispersystems.websocket.WebsocketHeaders;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import javax.annotation.Nullable;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest {
@@ -195,7 +197,7 @@ class MessageControllerTest {
.build();
@BeforeEach
void setup() {
void setup() throws MultiRecipientMismatchedDevicesException {
reset(pushNotificationScheduler);
when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
@@ -287,7 +289,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -332,7 +334,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -357,7 +359,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -395,7 +397,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -434,7 +436,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -530,6 +532,9 @@ class MessageControllerTest {
@Test
void testMultiDeviceMissing() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@@ -542,15 +547,16 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)),
asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@Test
void testMultiDeviceExtra() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@@ -563,10 +569,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)),
asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response2.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@@ -602,7 +606,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size());
@@ -626,7 +630,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size());
@@ -648,12 +652,17 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(messageSender).sendMessages(any(Account.class),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3));
any(),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
any());
}
}
@Test
void testRegistrationIdMismatch() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request()
@@ -665,10 +674,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
assertThat("Good Response Body",
asJson(response.readEntity(StaleDevices.class)),
asJson(response.readEntity(StaleDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@@ -1078,7 +1085,7 @@ class MessageControllerTest {
}
@Test
void testValidateContentLength() {
void testValidateContentLength() throws MismatchedDevicesException {
final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1);
final byte[] contentBytes = new byte[contentLength];
Arrays.fill(contentBytes, (byte) 1);
@@ -1095,7 +1102,7 @@ class MessageControllerTest {
assertThat("Bad response", response.getStatus(), is(equalTo(413)));
verify(messageSender, never()).sendMessages(any(), any());
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
}
}
@@ -1113,10 +1120,10 @@ class MessageControllerTest {
if (expectOk) {
assertEquals(200, response.getStatus());
verify(messageSender).sendMessages(any(), any());
verify(messageSender).sendMessages(any(), any(), any(), any());
} else {
assertEquals(422, response.getStatus());
verify(messageSender, never()).sendMessages(any(), any());
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
}
}
}
@@ -1140,7 +1147,9 @@ class MessageControllerTest {
final Optional<String> maybeGroupSendToken,
final int expectedStatus,
final Set<Account> expectedResolvedAccounts,
final Set<ServiceIdentifier> expectedUuids404) {
final Set<ServiceIdentifier> expectedUuids404,
@Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException)
throws MultiRecipientMismatchedDevicesException {
clock.pin(START_OF_DAY);
@@ -1151,6 +1160,11 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)))));
if (mismatchedDevicesException != null) {
doThrow(mismatchedDevicesException)
.when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean());
}
final boolean ephemeral = true;
final boolean urgent = false;
@@ -1187,7 +1201,7 @@ class MessageControllerTest {
assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404));
}
if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) {
if ((expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) || mismatchedDevicesException != null) {
verify(messageSender).sendMultiRecipientMessage(any(),
argThat(resolvedRecipients ->
new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)),
@@ -1267,7 +1281,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with combined UAKs",
accountsByServiceIdentifier,
@@ -1279,7 +1294,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with group send endorsement",
accountsByServiceIdentifier,
@@ -1291,7 +1307,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Incorrect combined UAK",
accountsByServiceIdentifier,
@@ -1303,7 +1320,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Incorrect group send endorsement",
accountsByServiceIdentifier,
@@ -1317,7 +1335,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
// Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if
// provided
@@ -1331,7 +1350,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story with group send endorsement",
accountsByServiceIdentifier,
@@ -1343,7 +1363,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Conflicting credentials",
accountsByServiceIdentifier,
@@ -1355,7 +1376,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("No credentials",
accountsByServiceIdentifier,
@@ -1367,7 +1389,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Oversized payload",
accountsByServiceIdentifier,
@@ -1383,7 +1406,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
413,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Negative timestamp",
accountsByServiceIdentifier,
@@ -1395,7 +1419,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Excessive timestamp",
accountsByServiceIdentifier,
@@ -1407,7 +1432,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Empty recipient list",
accountsByServiceIdentifier,
@@ -1421,7 +1447,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story with empty recipient list",
accountsByServiceIdentifier,
@@ -1433,7 +1460,8 @@ class MessageControllerTest {
Optional.empty(),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Duplicate recipient",
accountsByServiceIdentifier,
@@ -1447,7 +1475,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Missing account",
Map.of(),
@@ -1459,7 +1488,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Collections.emptySet(),
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci))),
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("One missing and one existing account",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@@ -1471,7 +1501,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Set.of(singleDeviceAccount),
Set.of(new AciServiceIdentifier(multiDeviceAccountAci))),
Set.of(new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("Missing account for story",
Map.of(),
@@ -1483,7 +1514,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Collections.emptySet(),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("One missing and one existing account for story",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@@ -1495,7 +1527,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Missing device",
accountsByServiceIdentifier,
@@ -1509,7 +1542,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
409,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Set.of((byte) (Device.PRIMARY_ID + 1)), Collections.emptySet(), Collections.emptySet())))),
Arguments.argumentSet("Extra device",
accountsByServiceIdentifier,
@@ -1525,7 +1560,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
409,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 2)), Collections.emptySet())))),
Arguments.argumentSet("Stale registration ID",
accountsByServiceIdentifier,
@@ -1540,7 +1577,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
410,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 1)))))),
Arguments.argumentSet("Rate-limited story",
accountsByServiceIdentifier,
@@ -1552,7 +1591,8 @@ class MessageControllerTest {
Optional.empty(),
429,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story to PNI recipients",
accountsByServiceIdentifier,
@@ -1567,7 +1607,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK",
accountsByServiceIdentifier,
@@ -1582,7 +1623,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement",
accountsByServiceIdentifier,
@@ -1599,7 +1641,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of())
Set.of(),
null)
);
}