mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-21 02:18:08 +01:00
Internalize destination device list/registration ID checks in MessageSender
This commit is contained in:
@@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.argThat;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.anyBoolean;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.reset;
|
||||
@@ -76,12 +77,12 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.SpamReport;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
|
||||
@@ -120,6 +121,7 @@ import org.whispersystems.websocket.WebsocketHeaders;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class MessageControllerTest {
|
||||
@@ -195,7 +197,7 @@ class MessageControllerTest {
|
||||
.build();
|
||||
|
||||
@BeforeEach
|
||||
void setup() {
|
||||
void setup() throws MultiRecipientMismatchedDevicesException {
|
||||
reset(pushNotificationScheduler);
|
||||
|
||||
when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
@@ -287,7 +289,7 @@ class MessageControllerTest {
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
@@ -332,7 +334,7 @@ class MessageControllerTest {
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
@@ -357,7 +359,7 @@ class MessageControllerTest {
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
@@ -395,7 +397,7 @@ class MessageControllerTest {
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
@@ -434,7 +436,7 @@ class MessageControllerTest {
|
||||
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
|
||||
if (expectedResponse == 200) {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
@@ -530,6 +532,9 @@ class MessageControllerTest {
|
||||
|
||||
@Test
|
||||
void testMultiDeviceMissing() throws Exception {
|
||||
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
|
||||
.when(messageSender).sendMessages(any(), any(), any(), any());
|
||||
|
||||
try (final Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
@@ -542,15 +547,16 @@ class MessageControllerTest {
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
|
||||
|
||||
assertThat("Good Response Body",
|
||||
asJson(response.readEntity(MismatchedDevices.class)),
|
||||
asJson(response.readEntity(MismatchedDevicesResponse.class)),
|
||||
is(equalTo(jsonFixture("fixtures/missing_device_response.json"))));
|
||||
|
||||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiDeviceExtra() throws Exception {
|
||||
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
|
||||
.when(messageSender).sendMessages(any(), any(), any(), any());
|
||||
|
||||
try (final Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
@@ -563,10 +569,8 @@ class MessageControllerTest {
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
|
||||
|
||||
assertThat("Good Response Body",
|
||||
asJson(response.readEntity(MismatchedDevices.class)),
|
||||
asJson(response.readEntity(MismatchedDevicesResponse.class)),
|
||||
is(equalTo(jsonFixture("fixtures/missing_device_response2.json"))));
|
||||
|
||||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -602,7 +606,7 @@ class MessageControllerTest {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(3, envelopeCaptor.getValue().size());
|
||||
|
||||
@@ -626,7 +630,7 @@ class MessageControllerTest {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(3, envelopeCaptor.getValue().size());
|
||||
|
||||
@@ -648,12 +652,17 @@ class MessageControllerTest {
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
verify(messageSender).sendMessages(any(Account.class),
|
||||
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3));
|
||||
any(),
|
||||
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
|
||||
any());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRegistrationIdMismatch() throws Exception {
|
||||
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
|
||||
.when(messageSender).sendMessages(any(), any(), any(), any());
|
||||
|
||||
try (final Response response =
|
||||
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
@@ -665,10 +674,8 @@ class MessageControllerTest {
|
||||
assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
|
||||
|
||||
assertThat("Good Response Body",
|
||||
asJson(response.readEntity(StaleDevices.class)),
|
||||
asJson(response.readEntity(StaleDevicesResponse.class)),
|
||||
is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json"))));
|
||||
|
||||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1078,7 +1085,7 @@ class MessageControllerTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void testValidateContentLength() {
|
||||
void testValidateContentLength() throws MismatchedDevicesException {
|
||||
final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1);
|
||||
final byte[] contentBytes = new byte[contentLength];
|
||||
Arrays.fill(contentBytes, (byte) 1);
|
||||
@@ -1095,7 +1102,7 @@ class MessageControllerTest {
|
||||
|
||||
assertThat("Bad response", response.getStatus(), is(equalTo(413)));
|
||||
|
||||
verify(messageSender, never()).sendMessages(any(), any());
|
||||
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1113,10 +1120,10 @@ class MessageControllerTest {
|
||||
|
||||
if (expectOk) {
|
||||
assertEquals(200, response.getStatus());
|
||||
verify(messageSender).sendMessages(any(), any());
|
||||
verify(messageSender).sendMessages(any(), any(), any(), any());
|
||||
} else {
|
||||
assertEquals(422, response.getStatus());
|
||||
verify(messageSender, never()).sendMessages(any(), any());
|
||||
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1140,7 +1147,9 @@ class MessageControllerTest {
|
||||
final Optional<String> maybeGroupSendToken,
|
||||
final int expectedStatus,
|
||||
final Set<Account> expectedResolvedAccounts,
|
||||
final Set<ServiceIdentifier> expectedUuids404) {
|
||||
final Set<ServiceIdentifier> expectedUuids404,
|
||||
@Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException)
|
||||
throws MultiRecipientMismatchedDevicesException {
|
||||
|
||||
clock.pin(START_OF_DAY);
|
||||
|
||||
@@ -1151,6 +1160,11 @@ class MessageControllerTest {
|
||||
when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)))));
|
||||
|
||||
if (mismatchedDevicesException != null) {
|
||||
doThrow(mismatchedDevicesException)
|
||||
.when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean());
|
||||
}
|
||||
|
||||
final boolean ephemeral = true;
|
||||
final boolean urgent = false;
|
||||
|
||||
@@ -1187,7 +1201,7 @@ class MessageControllerTest {
|
||||
assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404));
|
||||
}
|
||||
|
||||
if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) {
|
||||
if ((expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) || mismatchedDevicesException != null) {
|
||||
verify(messageSender).sendMultiRecipientMessage(any(),
|
||||
argThat(resolvedRecipients ->
|
||||
new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)),
|
||||
@@ -1267,7 +1281,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message with combined UAKs",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1279,7 +1294,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message with group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1291,7 +1307,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Incorrect combined UAK",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1303,7 +1320,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Incorrect group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1317,7 +1335,8 @@ class MessageControllerTest {
|
||||
START_OF_DAY.plus(Duration.ofDays(1)))),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
// Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if
|
||||
// provided
|
||||
@@ -1331,7 +1350,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Story with group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1343,7 +1363,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Conflicting credentials",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1355,7 +1376,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("No credentials",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1367,7 +1389,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Oversized payload",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1383,7 +1406,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
413,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Negative timestamp",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1395,7 +1419,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Excessive timestamp",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1407,7 +1432,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Empty recipient list",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1421,7 +1447,8 @@ class MessageControllerTest {
|
||||
START_OF_DAY.plus(Duration.ofDays(1)))),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Story with empty recipient list",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1433,7 +1460,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Duplicate recipient",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1447,7 +1475,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Missing account",
|
||||
Map.of(),
|
||||
@@ -1459,7 +1488,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
200,
|
||||
Collections.emptySet(),
|
||||
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci))),
|
||||
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("One missing and one existing account",
|
||||
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
|
||||
@@ -1471,7 +1501,8 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
200,
|
||||
Set.of(singleDeviceAccount),
|
||||
Set.of(new AciServiceIdentifier(multiDeviceAccountAci))),
|
||||
Set.of(new AciServiceIdentifier(multiDeviceAccountAci)),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Missing account for story",
|
||||
Map.of(),
|
||||
@@ -1483,7 +1514,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
200,
|
||||
Collections.emptySet(),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("One missing and one existing account for story",
|
||||
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
|
||||
@@ -1495,7 +1527,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Missing device",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1509,7 +1542,9 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
409,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
|
||||
new MismatchedDevices(Set.of((byte) (Device.PRIMARY_ID + 1)), Collections.emptySet(), Collections.emptySet())))),
|
||||
|
||||
Arguments.argumentSet("Extra device",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1525,7 +1560,9 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
409,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
|
||||
new MismatchedDevices(Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 2)), Collections.emptySet())))),
|
||||
|
||||
Arguments.argumentSet("Stale registration ID",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1540,7 +1577,9 @@ class MessageControllerTest {
|
||||
Optional.of(groupSendEndorsement),
|
||||
410,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
|
||||
new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 1)))))),
|
||||
|
||||
Arguments.argumentSet("Rate-limited story",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1552,7 +1591,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
429,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Story to PNI recipients",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1567,7 +1607,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1582,7 +1623,8 @@ class MessageControllerTest {
|
||||
Optional.empty(),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
@@ -1599,7 +1641,8 @@ class MessageControllerTest {
|
||||
START_OF_DAY.plus(Duration.ofDays(1)))),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of())
|
||||
Set.of(),
|
||||
null)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,9 @@ import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
@@ -30,12 +33,22 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.protocol.InvalidVersionException;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
|
||||
|
||||
class MessageSenderTest {
|
||||
|
||||
@@ -60,7 +73,9 @@ class MessageSenderTest {
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
@@ -71,7 +86,11 @@ class MessageSenderTest {
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
@@ -82,7 +101,10 @@ class MessageSenderTest {
|
||||
|
||||
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
|
||||
assertDoesNotThrow(() -> messageSender.sendMessages(account,
|
||||
serviceIdentifier,
|
||||
Map.of(device.getId(), message),
|
||||
Map.of(device.getId(), registrationId)));
|
||||
|
||||
final MessageProtos.Envelope expectedMessage = ephemeral
|
||||
? message.toBuilder().setEphemeral(true).build()
|
||||
@@ -97,23 +119,61 @@ class MessageSenderTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendMessageMismatchedDevices() {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder().build();
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final MismatchedDevicesException mismatchedDevicesException =
|
||||
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
|
||||
serviceIdentifier,
|
||||
Map.of(device.getId(), message),
|
||||
Map.of(device.getId(), registrationId + 1)));
|
||||
|
||||
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
|
||||
mismatchedDevicesException.getMismatchedDevices());
|
||||
}
|
||||
|
||||
@CartesianTest
|
||||
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken)
|
||||
throws NotPushRegisteredException, InvalidMessageException, InvalidVersionException {
|
||||
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
@@ -125,12 +185,19 @@ class MessageSenderTest {
|
||||
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
|
||||
Collections.emptyMap(),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
ephemeral,
|
||||
urgent)
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
|
||||
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
|
||||
|
||||
final SealedSenderMultiRecipientMessage.Recipient recipient =
|
||||
multiRecipientMessage.getRecipients().values().iterator().next();
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
|
||||
Map.of(recipient, account),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
ephemeral,
|
||||
urgent)
|
||||
.join());
|
||||
|
||||
if (expectPushNotificationAttempt) {
|
||||
@@ -140,6 +207,49 @@ class MessageSenderTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendMultiRecipientMessageMismatchedDevices() throws InvalidMessageException, InvalidVersionException {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
|
||||
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId + 1, new byte[48]))));
|
||||
|
||||
final SealedSenderMultiRecipientMessage.Recipient recipient =
|
||||
multiRecipientMessage.getRecipients().values().iterator().next();
|
||||
|
||||
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, true))));
|
||||
|
||||
final MultiRecipientMismatchedDevicesException mismatchedDevicesException =
|
||||
assertThrows(MultiRecipientMismatchedDevicesException.class,
|
||||
() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
|
||||
Map.of(recipient, account),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
false,
|
||||
true)
|
||||
.join());
|
||||
|
||||
assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))),
|
||||
mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getDeliveryChannelName(final Device device, final String expectedChannelName) {
|
||||
@@ -183,4 +293,87 @@ class MessageSenderTest {
|
||||
assertDoesNotThrow(() ->
|
||||
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getMismatchedDevices(final Account account,
|
||||
final ServiceIdentifier serviceIdentifier,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
final byte excludedDeviceId,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<MismatchedDevices> expectedMismatchedDevices) {
|
||||
|
||||
assertEquals(expectedMismatchedDevices,
|
||||
MessageSender.getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, excludedDeviceId));
|
||||
}
|
||||
|
||||
private static List<Arguments> getMismatchedDevices() {
|
||||
final byte primaryDeviceId = Device.PRIMARY_ID;
|
||||
final byte linkedDeviceId = primaryDeviceId + 1;
|
||||
final byte extraDeviceId = linkedDeviceId + 1;
|
||||
|
||||
final int primaryDeviceAciRegistrationId = 2;
|
||||
final int primaryDevicePniRegistrationId = 3;
|
||||
final int linkedDeviceAciRegistrationId = 5;
|
||||
final int linkedDevicePniRegistrationId = 7;
|
||||
|
||||
final Device primaryDevice = mock(Device.class);
|
||||
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
|
||||
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
|
||||
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId));
|
||||
|
||||
final Device linkedDevice = mock(Device.class);
|
||||
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
|
||||
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
|
||||
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId));
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
|
||||
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
|
||||
when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice));
|
||||
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
|
||||
|
||||
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||
final PniServiceIdentifier pniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
|
||||
|
||||
return List.of(
|
||||
Arguments.argumentSet("Complete device list for ACI, no devices excluded",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
primaryDeviceId, primaryDeviceAciRegistrationId,
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Complete device list for PNI, no devices excluded",
|
||||
account,
|
||||
pniServiceIdentifier,
|
||||
Map.of(
|
||||
primaryDeviceId, primaryDevicePniRegistrationId,
|
||||
linkedDeviceId, linkedDevicePniRegistrationId
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Complete device list, device excluded",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId
|
||||
),
|
||||
primaryDeviceId,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Mismatched devices",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId + 1,
|
||||
extraDeviceId, 17
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,10 +60,12 @@ import java.util.function.Consumer;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.function.Executable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
@@ -76,6 +78,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
|
||||
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
@@ -1705,4 +1708,47 @@ class AccountsManagerTest {
|
||||
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds, @Nullable final MismatchedDevicesException expectedException) {
|
||||
final Executable validateCompleteDeviceListExecutable =
|
||||
() -> AccountsManager.validateCompleteDeviceList(account, deviceIds);
|
||||
|
||||
if (expectedException != null) {
|
||||
final MismatchedDevicesException caughtException =
|
||||
assertThrows(MismatchedDevicesException.class, validateCompleteDeviceListExecutable);
|
||||
|
||||
assertEquals(expectedException.getMismatchedDevices(), caughtException.getMismatchedDevices());
|
||||
} else {
|
||||
assertDoesNotThrow(validateCompleteDeviceListExecutable);
|
||||
}
|
||||
}
|
||||
|
||||
private static List<Arguments> validateCompleteDeviceList() {
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final byte extraDeviceId = deviceId + 1;
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
|
||||
return List.of(
|
||||
Arguments.of(account, Set.of(deviceId), null),
|
||||
|
||||
Arguments.of(account, Set.of(deviceId, extraDeviceId),
|
||||
new MismatchedDevicesException(
|
||||
new MismatchedDevices(Collections.emptySet(), Set.of(extraDeviceId), Collections.emptySet()))),
|
||||
|
||||
Arguments.of(account, Collections.emptySet(),
|
||||
new MismatchedDevicesException(
|
||||
new MismatchedDevices(Set.of(deviceId), Collections.emptySet(), Collections.emptySet()))),
|
||||
|
||||
Arguments.of(account, Set.of(extraDeviceId),
|
||||
new MismatchedDevicesException(
|
||||
new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet())))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,11 +32,12 @@ import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||
|
||||
@@ -105,7 +106,7 @@ public class ChangeNumberManagerTest {
|
||||
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null);
|
||||
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
|
||||
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -119,7 +120,7 @@ public class ChangeNumberManagerTest {
|
||||
|
||||
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null);
|
||||
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -159,7 +160,7 @@ public class ChangeNumberManagerTest {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
@@ -212,7 +213,7 @@ public class ChangeNumberManagerTest {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
@@ -263,7 +264,7 @@ public class ChangeNumberManagerTest {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
@@ -310,7 +311,7 @@ public class ChangeNumberManagerTest {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
@@ -359,7 +360,7 @@ public class ChangeNumberManagerTest {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
@@ -372,82 +373,6 @@ public class ChangeNumberManagerTest {
|
||||
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
|
||||
}
|
||||
|
||||
@Test
|
||||
void changeNumberMismatchedRegistrationId() {
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
|
||||
for (byte i = 1; i <= 3; i++) {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(i);
|
||||
when(device.getRegistrationId()).thenReturn((int) i);
|
||||
|
||||
devices.add(device);
|
||||
when(account.getDevice(i)).thenReturn(Optional.of(device));
|
||||
}
|
||||
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
final byte destinationDeviceId2 = 2;
|
||||
final byte destinationDeviceId3 = 3;
|
||||
final List<IncomingMessage> messages = List.of(
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
|
||||
|
||||
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
|
||||
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
|
||||
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
|
||||
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
|
||||
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
|
||||
destinationDeviceId3, 89);
|
||||
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void updatePniKeysMismatchedRegistrationId() {
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
|
||||
for (byte i = 1; i <= 3; i++) {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(i);
|
||||
when(device.getRegistrationId()).thenReturn((int) i);
|
||||
|
||||
devices.add(device);
|
||||
when(account.getDevice(i)).thenReturn(Optional.of(device));
|
||||
}
|
||||
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
final byte destinationDeviceId2 = 2;
|
||||
final byte destinationDeviceId3 = 3;
|
||||
final List<IncomingMessage> messages = List.of(
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
|
||||
|
||||
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
|
||||
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
|
||||
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
|
||||
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
|
||||
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
|
||||
destinationDeviceId3, 89);
|
||||
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void changeNumberMissingData() {
|
||||
final Account account = mock(Account.class);
|
||||
|
||||
@@ -1,273 +0,0 @@
|
||||
/*
|
||||
* Copyright 2013-2022 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Stream;
|
||||
import org.assertj.core.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class DestinationDeviceValidatorTest {
|
||||
|
||||
static Account mockAccountWithDeviceAndRegId(final Map<Byte, Integer> registrationIdsByDeviceId) {
|
||||
final Account account = mock(Account.class);
|
||||
|
||||
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
});
|
||||
|
||||
return account;
|
||||
}
|
||||
|
||||
static Stream<Arguments> validateRegistrationIdsSource() {
|
||||
final byte id1 = 1;
|
||||
final byte id2 = 2;
|
||||
final byte id3 = 3;
|
||||
return Stream.of(
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)),
|
||||
Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF),
|
||||
null),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
|
||||
Map.of(id1, 1492),
|
||||
Set.of(id1)),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
|
||||
Map.of(id1, 42),
|
||||
null),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
|
||||
Map.of(id1, 0),
|
||||
null),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)),
|
||||
Map.of(id1, 0, id2, 42),
|
||||
Set.of(id2)),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)),
|
||||
Map.of(id1, 41, id2, 257),
|
||||
Set.of(id1, id2))
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("validateRegistrationIdsSource")
|
||||
void testValidateRegistrationIds(
|
||||
Account account,
|
||||
Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
Set<Byte> expectedStaleDeviceIds) throws Exception {
|
||||
if (expectedStaleDeviceIds != null) {
|
||||
Assertions.assertThat(assertThrows(StaleDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(
|
||||
account,
|
||||
registrationIdsByDeviceId.entrySet(),
|
||||
Map.Entry::getKey,
|
||||
Map.Entry::getValue,
|
||||
false))
|
||||
.getStaleDevices())
|
||||
.hasSameElementsAs(expectedStaleDeviceIds);
|
||||
} else {
|
||||
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(),
|
||||
Map.Entry::getKey, Map.Entry::getValue, false);
|
||||
}
|
||||
}
|
||||
|
||||
static Account mockAccountWithDeviceAndEnabled(final Map<Byte, Boolean> enabledStateByDeviceId) {
|
||||
final Account account = mock(Account.class);
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
|
||||
enabledStateByDeviceId.forEach((deviceId, enabled) -> {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
|
||||
devices.add(device);
|
||||
});
|
||||
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
return account;
|
||||
}
|
||||
|
||||
static Stream<Arguments> validateCompleteDeviceList() {
|
||||
final byte id1 = 1;
|
||||
final byte id2 = 2;
|
||||
final byte id3 = 3;
|
||||
|
||||
final Account account = mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true));
|
||||
|
||||
return Stream.of(
|
||||
// Device IDs provided for all enabled devices
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1, id3),
|
||||
Set.of(id2),
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID provided for disabled device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1, id2, id3),
|
||||
null,
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID omitted for enabled device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1),
|
||||
Set.of(id2, id3),
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID included for disabled device, omitted for enabled device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1, id2),
|
||||
Set.of(id3),
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID omitted for enabled device, included for device in excluded list
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1),
|
||||
Set.of(id2, id3),
|
||||
Set.of(id1),
|
||||
Set.of(id1)
|
||||
),
|
||||
|
||||
// Device ID omitted for enabled device, included for disabled device, omitted for excluded device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id2),
|
||||
Set.of(id3),
|
||||
null,
|
||||
Set.of(id1)
|
||||
),
|
||||
|
||||
// Device ID included for enabled device, omitted for excluded device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id3),
|
||||
Set.of(id2),
|
||||
null,
|
||||
Set.of(id1)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void validateCompleteDeviceList(
|
||||
Account account,
|
||||
Set<Byte> deviceIds,
|
||||
Collection<Byte> expectedMissingDeviceIds,
|
||||
Collection<Byte> expectedExtraDeviceIds,
|
||||
Set<Byte> excludedDeviceIds) throws Exception {
|
||||
|
||||
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
|
||||
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds));
|
||||
if (expectedMissingDeviceIds != null) {
|
||||
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
|
||||
.hasSameElementsAs(expectedMissingDeviceIds);
|
||||
}
|
||||
if (expectedExtraDeviceIds != null) {
|
||||
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
|
||||
}
|
||||
} else {
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDuplicateDeviceIds() {
|
||||
final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.PRIMARY_ID, 17));
|
||||
try {
|
||||
DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, 16), new Pair<>(Device.PRIMARY_ID, 17)), false);
|
||||
Assertions.fail("duplicate devices should throw StaleDevicesException");
|
||||
} catch (StaleDevicesException e) {
|
||||
Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.PRIMARY_ID));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testValidatePniRegistrationIds() {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
|
||||
|
||||
final int aciRegistrationId = 17;
|
||||
final int pniRegistrationId = 89;
|
||||
final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId;
|
||||
|
||||
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
|
||||
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
|
||||
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), false));
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
|
||||
true));
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
|
||||
true));
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
|
||||
false));
|
||||
|
||||
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
|
||||
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
|
||||
false));
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
|
||||
true));
|
||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), true));
|
||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), false));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user