Internalize destination device list/registration ID checks in MessageSender

This commit is contained in:
Jon Chambers
2025-04-07 09:15:39 -04:00
committed by GitHub
parent 1d0e2d29a7
commit c6689ca07a
21 changed files with 675 additions and 755 deletions

View File

@@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
@@ -76,12 +77,12 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
@@ -120,6 +121,7 @@ import org.whispersystems.websocket.WebsocketHeaders;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import javax.annotation.Nullable;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest {
@@ -195,7 +197,7 @@ class MessageControllerTest {
.build();
@BeforeEach
void setup() {
void setup() throws MultiRecipientMismatchedDevicesException {
reset(pushNotificationScheduler);
when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
@@ -287,7 +289,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -332,7 +334,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -357,7 +359,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -395,7 +397,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -434,7 +436,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@@ -530,6 +532,9 @@ class MessageControllerTest {
@Test
void testMultiDeviceMissing() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@@ -542,15 +547,16 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)),
asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@Test
void testMultiDeviceExtra() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@@ -563,10 +569,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)),
asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response2.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@@ -602,7 +606,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size());
@@ -626,7 +630,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size());
@@ -648,12 +652,17 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(messageSender).sendMessages(any(Account.class),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3));
any(),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
any());
}
}
@Test
void testRegistrationIdMismatch() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request()
@@ -665,10 +674,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
assertThat("Good Response Body",
asJson(response.readEntity(StaleDevices.class)),
asJson(response.readEntity(StaleDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@@ -1078,7 +1085,7 @@ class MessageControllerTest {
}
@Test
void testValidateContentLength() {
void testValidateContentLength() throws MismatchedDevicesException {
final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1);
final byte[] contentBytes = new byte[contentLength];
Arrays.fill(contentBytes, (byte) 1);
@@ -1095,7 +1102,7 @@ class MessageControllerTest {
assertThat("Bad response", response.getStatus(), is(equalTo(413)));
verify(messageSender, never()).sendMessages(any(), any());
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
}
}
@@ -1113,10 +1120,10 @@ class MessageControllerTest {
if (expectOk) {
assertEquals(200, response.getStatus());
verify(messageSender).sendMessages(any(), any());
verify(messageSender).sendMessages(any(), any(), any(), any());
} else {
assertEquals(422, response.getStatus());
verify(messageSender, never()).sendMessages(any(), any());
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
}
}
}
@@ -1140,7 +1147,9 @@ class MessageControllerTest {
final Optional<String> maybeGroupSendToken,
final int expectedStatus,
final Set<Account> expectedResolvedAccounts,
final Set<ServiceIdentifier> expectedUuids404) {
final Set<ServiceIdentifier> expectedUuids404,
@Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException)
throws MultiRecipientMismatchedDevicesException {
clock.pin(START_OF_DAY);
@@ -1151,6 +1160,11 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)))));
if (mismatchedDevicesException != null) {
doThrow(mismatchedDevicesException)
.when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean());
}
final boolean ephemeral = true;
final boolean urgent = false;
@@ -1187,7 +1201,7 @@ class MessageControllerTest {
assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404));
}
if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) {
if ((expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) || mismatchedDevicesException != null) {
verify(messageSender).sendMultiRecipientMessage(any(),
argThat(resolvedRecipients ->
new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)),
@@ -1267,7 +1281,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with combined UAKs",
accountsByServiceIdentifier,
@@ -1279,7 +1294,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with group send endorsement",
accountsByServiceIdentifier,
@@ -1291,7 +1307,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Incorrect combined UAK",
accountsByServiceIdentifier,
@@ -1303,7 +1320,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Incorrect group send endorsement",
accountsByServiceIdentifier,
@@ -1317,7 +1335,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
// Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if
// provided
@@ -1331,7 +1350,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story with group send endorsement",
accountsByServiceIdentifier,
@@ -1343,7 +1363,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Conflicting credentials",
accountsByServiceIdentifier,
@@ -1355,7 +1376,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("No credentials",
accountsByServiceIdentifier,
@@ -1367,7 +1389,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Oversized payload",
accountsByServiceIdentifier,
@@ -1383,7 +1406,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
413,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Negative timestamp",
accountsByServiceIdentifier,
@@ -1395,7 +1419,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Excessive timestamp",
accountsByServiceIdentifier,
@@ -1407,7 +1432,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Empty recipient list",
accountsByServiceIdentifier,
@@ -1421,7 +1447,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story with empty recipient list",
accountsByServiceIdentifier,
@@ -1433,7 +1460,8 @@ class MessageControllerTest {
Optional.empty(),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Duplicate recipient",
accountsByServiceIdentifier,
@@ -1447,7 +1475,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Missing account",
Map.of(),
@@ -1459,7 +1488,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Collections.emptySet(),
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci))),
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("One missing and one existing account",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@@ -1471,7 +1501,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Set.of(singleDeviceAccount),
Set.of(new AciServiceIdentifier(multiDeviceAccountAci))),
Set.of(new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("Missing account for story",
Map.of(),
@@ -1483,7 +1514,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Collections.emptySet(),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("One missing and one existing account for story",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@@ -1495,7 +1527,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Missing device",
accountsByServiceIdentifier,
@@ -1509,7 +1542,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
409,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Set.of((byte) (Device.PRIMARY_ID + 1)), Collections.emptySet(), Collections.emptySet())))),
Arguments.argumentSet("Extra device",
accountsByServiceIdentifier,
@@ -1525,7 +1560,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
409,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 2)), Collections.emptySet())))),
Arguments.argumentSet("Stale registration ID",
accountsByServiceIdentifier,
@@ -1540,7 +1577,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
410,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 1)))))),
Arguments.argumentSet("Rate-limited story",
accountsByServiceIdentifier,
@@ -1552,7 +1591,8 @@ class MessageControllerTest {
Optional.empty(),
429,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story to PNI recipients",
accountsByServiceIdentifier,
@@ -1567,7 +1607,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK",
accountsByServiceIdentifier,
@@ -1582,7 +1623,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement",
accountsByServiceIdentifier,
@@ -1599,7 +1641,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of())
Set.of(),
null)
);
}

View File

@@ -22,6 +22,9 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
@@ -30,12 +33,22 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.InvalidVersionException;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
class MessageSenderTest {
@@ -60,7 +73,9 @@ class MessageSenderTest {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
@@ -71,7 +86,11 @@ class MessageSenderTest {
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token");
@@ -82,7 +101,10 @@ class MessageSenderTest {
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
assertDoesNotThrow(() -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId)));
final MessageProtos.Envelope expectedMessage = ephemeral
? message.toBuilder().setEphemeral(true).build()
@@ -97,23 +119,61 @@ class MessageSenderTest {
}
}
@Test
void sendMessageMismatchedDevices() {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder().build();
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final MismatchedDevicesException mismatchedDevicesException =
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId + 1)));
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
mismatchedDevicesException.getMismatchedDevices());
}
@CartesianTest
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken)
throws NotPushRegisteredException, InvalidMessageException, InvalidVersionException {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token");
@@ -125,12 +185,19 @@ class MessageSenderTest {
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
Collections.emptyMap(),
System.currentTimeMillis(),
false,
ephemeral,
urgent)
final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
final SealedSenderMultiRecipientMessage.Recipient recipient =
multiRecipientMessage.getRecipients().values().iterator().next();
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
Map.of(recipient, account),
System.currentTimeMillis(),
false,
ephemeral,
urgent)
.join());
if (expectPushNotificationAttempt) {
@@ -140,6 +207,49 @@ class MessageSenderTest {
}
}
@Test
void sendMultiRecipientMessageMismatchedDevices() throws InvalidMessageException, InvalidVersionException {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId + 1, new byte[48]))));
final SealedSenderMultiRecipientMessage.Recipient recipient =
multiRecipientMessage.getRecipients().values().iterator().next();
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, true))));
final MultiRecipientMismatchedDevicesException mismatchedDevicesException =
assertThrows(MultiRecipientMismatchedDevicesException.class,
() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
Map.of(recipient, account),
System.currentTimeMillis(),
false,
false,
true)
.join());
assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))),
mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier());
}
@ParameterizedTest
@MethodSource
void getDeliveryChannelName(final Device device, final String expectedChannelName) {
@@ -183,4 +293,87 @@ class MessageSenderTest {
assertDoesNotThrow(() ->
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null));
}
@ParameterizedTest
@MethodSource
void getMismatchedDevices(final Account account,
final ServiceIdentifier serviceIdentifier,
final Map<Byte, Integer> registrationIdsByDeviceId,
final byte excludedDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<MismatchedDevices> expectedMismatchedDevices) {
assertEquals(expectedMismatchedDevices,
MessageSender.getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, excludedDeviceId));
}
private static List<Arguments> getMismatchedDevices() {
final byte primaryDeviceId = Device.PRIMARY_ID;
final byte linkedDeviceId = primaryDeviceId + 1;
final byte extraDeviceId = linkedDeviceId + 1;
final int primaryDeviceAciRegistrationId = 2;
final int primaryDevicePniRegistrationId = 3;
final int linkedDeviceAciRegistrationId = 5;
final int linkedDevicePniRegistrationId = 7;
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId));
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId));
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final PniServiceIdentifier pniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
return List.of(
Arguments.argumentSet("Complete device list for ACI, no devices excluded",
account,
aciServiceIdentifier,
Map.of(
primaryDeviceId, primaryDeviceAciRegistrationId,
linkedDeviceId, linkedDeviceAciRegistrationId
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.empty()),
Arguments.argumentSet("Complete device list for PNI, no devices excluded",
account,
pniServiceIdentifier,
Map.of(
primaryDeviceId, primaryDevicePniRegistrationId,
linkedDeviceId, linkedDevicePniRegistrationId
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.empty()),
Arguments.argumentSet("Complete device list, device excluded",
account,
aciServiceIdentifier,
Map.of(
linkedDeviceId, linkedDeviceAciRegistrationId
),
primaryDeviceId,
Optional.empty()),
Arguments.argumentSet("Mismatched devices",
account,
aciServiceIdentifier,
Map.of(
linkedDeviceId, linkedDeviceAciRegistrationId + 1,
extraDeviceId, 17
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
);
}
}

View File

@@ -60,10 +60,12 @@ import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.crypto.spec.SecretKeySpec;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
@@ -76,6 +78,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@@ -1705,4 +1708,47 @@ class AccountsManagerTest {
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
@ParameterizedTest
@MethodSource
void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds, @Nullable final MismatchedDevicesException expectedException) {
final Executable validateCompleteDeviceListExecutable =
() -> AccountsManager.validateCompleteDeviceList(account, deviceIds);
if (expectedException != null) {
final MismatchedDevicesException caughtException =
assertThrows(MismatchedDevicesException.class, validateCompleteDeviceListExecutable);
assertEquals(expectedException.getMismatchedDevices(), caughtException.getMismatchedDevices());
} else {
assertDoesNotThrow(validateCompleteDeviceListExecutable);
}
}
private static List<Arguments> validateCompleteDeviceList() {
final byte deviceId = Device.PRIMARY_ID;
final byte extraDeviceId = deviceId + 1;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
return List.of(
Arguments.of(account, Set.of(deviceId), null),
Arguments.of(account, Set.of(deviceId, extraDeviceId),
new MismatchedDevicesException(
new MismatchedDevices(Collections.emptySet(), Set.of(extraDeviceId), Collections.emptySet()))),
Arguments.of(account, Collections.emptySet(),
new MismatchedDevicesException(
new MismatchedDevices(Set.of(deviceId), Collections.emptySet(), Collections.emptySet()))),
Arguments.of(account, Set.of(extraDeviceId),
new MismatchedDevicesException(
new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet())))
);
}
}

View File

@@ -32,11 +32,12 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@@ -105,7 +106,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessages(eq(account), any());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
}
@Test
@@ -119,7 +120,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null);
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessages(eq(account), any());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
}
@Test
@@ -159,7 +160,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@@ -212,7 +213,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@@ -263,7 +264,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@@ -310,7 +311,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@@ -359,7 +360,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@@ -372,82 +373,6 @@ public class ChangeNumberManagerTest {
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@Test
void changeNumberMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
}
@Test
void updatePniKeysMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
}
@Test
void changeNumberMissingData() {
final Account account = mock(Account.class);

View File

@@ -1,273 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ExtendWith(DropwizardExtensionsSupport.class)
class DestinationDeviceValidatorTest {
static Account mockAccountWithDeviceAndRegId(final Map<Byte, Integer> registrationIdsByDeviceId) {
final Account account = mock(Account.class);
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
final Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
});
return account;
}
static Stream<Arguments> validateRegistrationIdsSource() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)),
Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 1492),
Set.of(id1)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)),
Map.of(id1, 0, id2, 42),
Set.of(id2)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)),
Map.of(id1, 41, id2, 257),
Set.of(id1, id2))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Map<Byte, Integer> registrationIdsByDeviceId,
Set<Byte> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(
account,
registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey,
Map.Entry::getValue,
false))
.getStaleDevices())
.hasSameElementsAs(expectedStaleDeviceIds);
} else {
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey, Map.Entry::getValue, false);
}
}
static Account mockAccountWithDeviceAndEnabled(final Map<Byte, Boolean> enabledStateByDeviceId) {
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
enabledStateByDeviceId.forEach((deviceId, enabled) -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
});
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceList() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
final Account account = mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true));
return Stream.of(
// Device IDs provided for all enabled devices
arguments(
account,
Set.of(id1, id3),
Set.of(id2),
null,
Collections.emptySet()),
// Device ID provided for disabled device
arguments(
account,
Set.of(id1, id2, id3),
null,
null,
Collections.emptySet()),
// Device ID omitted for enabled device
arguments(
account,
Set.of(id1),
Set.of(id2, id3),
null,
Collections.emptySet()),
// Device ID included for disabled device, omitted for enabled device
arguments(
account,
Set.of(id1, id2),
Set.of(id3),
null,
Collections.emptySet()),
// Device ID omitted for enabled device, included for device in excluded list
arguments(
account,
Set.of(id1),
Set.of(id2, id3),
Set.of(id1),
Set.of(id1)
),
// Device ID omitted for enabled device, included for disabled device, omitted for excluded device
arguments(
account,
Set.of(id2),
Set.of(id3),
null,
Set.of(id1)
),
// Device ID included for enabled device, omitted for excluded device
arguments(
account,
Set.of(id3),
Set.of(id2),
null,
Set.of(id1)
)
);
}
@ParameterizedTest
@MethodSource
void validateCompleteDeviceList(
Account account,
Set<Byte> deviceIds,
Collection<Byte> expectedMissingDeviceIds,
Collection<Byte> expectedExtraDeviceIds,
Set<Byte> excludedDeviceIds) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds);
}
}
@Test
void testDuplicateDeviceIds() {
final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.PRIMARY_ID, 17));
try {
DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, 16), new Pair<>(Device.PRIMARY_ID, 17)), false);
Assertions.fail("duplicate devices should throw StaleDevicesException");
} catch (StaleDevicesException e) {
Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.PRIMARY_ID));
}
}
@Test
void testValidatePniRegistrationIds() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
final int aciRegistrationId = 17;
final int pniRegistrationId = 89;
final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId;
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), false));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
false));
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
false));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), false));
}
}