Add support for setting PNI-associated registration IDs and identity keys when changing numbers

This commit is contained in:
Jon Chambers
2022-07-26 15:19:27 -04:00
committed by GitHub
parent c252118cfc
commit dce391a248
26 changed files with 927 additions and 673 deletions

View File

@@ -253,9 +253,10 @@ class AccountControllerTest {
when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername"))
.thenThrow(new UsernameNotAvailableException());
when(changeNumberManager.changeNumber(any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final String pniIdentityKey = invocation.getArgument(2, String.class);
final UUID uuid = account.getUuid();
final List<Device> devices = account.getDevices();
@@ -263,8 +264,10 @@ class AccountControllerTest {
final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
@@ -1298,10 +1301,10 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any());
assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.getNumber()).isEqualTo(number);
@@ -1318,12 +1321,12 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.readEntity(String.class)).isBlank();
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@@ -1336,7 +1339,7 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
@@ -1345,7 +1348,7 @@ class AccountControllerTest {
assertThat(responseEntity.getOriginalNumber()).isEqualTo(number);
assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111");
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@@ -1355,10 +1358,10 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
}
@Test
@@ -1373,11 +1376,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@@ -1393,11 +1396,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code + "-incorrect", null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@@ -1423,11 +1426,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
}
@Test
@@ -1453,11 +1456,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@@ -1485,11 +1488,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
}
@Test
@@ -1517,82 +1520,33 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any());
}
@Test
void testChangePhoneNumberDeviceMessagesWithoutPrekeys() throws Exception {
final String number = "+18005559876";
final String code = "987654";
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null,
List.of(new IncomingMessage(1, null, 1, 1, "foo")), null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
}
@Test
void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchDeviceIDs() throws Exception {
final String number = "+18005559876";
final String code = "987654";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
List.of(
new IncomingMessage(1, null, 2, 1, "foo"),
new IncomingMessage(1, null, 4, 1, "foo")),
Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
}
@Test
void testChangePhoneNumberChangePrekeys() throws Exception {
final String number = "+18005559876";
final String code = "987654";
final String pniIdentityKey = "changed-pni-identity-key";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
when(device2.getRegistrationId()).thenReturn(2);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
@@ -1601,6 +1555,8 @@ class AccountControllerTest {
new IncomingMessage(1, null, 3, 3, "content3"));
var deviceKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final AccountIdentityResponse accountIdentityResponse =
resources.getJerseyTest()
.target("/v1/accounts/number")
@@ -1608,53 +1564,18 @@ class AccountControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
deviceMessages,
deviceKeys),
pniIdentityKey, deviceMessages,
deviceKeys,
registrationIds),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any());
assertThat(accountIdentityResponse.getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.getNumber()).isEqualTo(number);
assertThat(accountIdentityResponse.getPni()).isNotEqualTo(AuthHelper.VALID_PNI);
}
@Test
void testChangePhoneNumberChangePrekeysDeviceMessagesMismatchRegistrationID() throws Exception {
final String number = "+18005559876";
final String code = "987654";
Device device2 = mock(Device.class);
when(device2.getId()).thenReturn(2L);
when(device2.isEnabled()).thenReturn(true);
when(device2.getRegistrationId()).thenReturn(2);
Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
final Response response =
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
List.of(
new IncomingMessage(1, null, 2, 1, "foo"),
new IncomingMessage(1, null, 3, 1, "foo")),
Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey())),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(410);
verify(accountsManager, never()).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
void testSetRegistrationLock() {
Response response =

View File

@@ -28,6 +28,7 @@ import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
@@ -73,6 +74,8 @@ class KeysControllerTest {
private static final int SAMPLE_REGISTRATION_ID2 = 1002;
private static final int SAMPLE_REGISTRATION_ID4 = 1555;
private static final int SAMPLE_PNI_REGISTRATION_ID = 1717;
private final PreKey SAMPLE_KEY = new PreKey(1234, "test1");
private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3");
private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5");
@@ -106,9 +109,11 @@ class KeysControllerTest {
.addResource(new RateLimitExceededExceptionMapper())
.build();
private Device sampleDevice;
@BeforeEach
void setup() {
final Device sampleDevice = mock(Device.class);
sampleDevice = mock(Device.class);
final Device sampleDevice2 = mock(Device.class);
final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class);
@@ -121,6 +126,7 @@ class KeysControllerTest {
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(SAMPLE_PNI_REGISTRATION_ID));
when(sampleDevice.isEnabled()).thenReturn(true);
when(sampleDevice2.isEnabled()).thenReturn(true);
when(sampleDevice3.isEnabled()).thenReturn(false);
@@ -284,6 +290,7 @@ class KeysControllerTest {
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
verify(KEYS).take(EXISTS_UUID, 1);
@@ -302,6 +309,28 @@ class KeysControllerTest {
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestByPhoneNumberIdentifierNoPniRegistrationIdTestV2() {
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1);

View File

@@ -1,227 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.MessageValidation;
import org.whispersystems.textsecuregcm.util.Pair;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageValidationTest {
static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) {
Account account = mock(Account.class);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
}
return account;
}
static Collection<Pair<Long, Integer>> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) {
final Collection<Pair<Long, Integer>> result = new HashSet<>(deviceAndRegistrationIds.length);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
result.add(new Pair<>(deviceId, registrationId));
}
return result;
}
static Stream<Arguments> validateRegistrationIdsSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 1492),
Set.of(1L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 255),
deviceAndRegistrationIds(1L, 0, 2L, 42),
Set.of(2L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 256),
deviceAndRegistrationIds(1L, 41, 2L, 257),
Set.of(1L, 2L))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Collection<Pair<Long, Integer>> deviceAndRegistrationIds,
Set<Long> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> {
MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds);
} else {
MessageValidation.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}
}
static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) {
Account account = mock(Account.class);
if (deviceIdAndEnabled.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
final List<Device> devices = new ArrayList<>(deviceIdAndEnabled.length / 2);
for (int i = 0; i < deviceIdAndEnabled.length; i+=2) {
if (!(deviceIdAndEnabled[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) {
throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1));
}
Long deviceId = (Long) deviceIdAndEnabled[i];
Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1];
Device device = mock(Device.class);
when(device.isEnabled()).thenReturn(enabled);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
}
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceListSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 3L),
null,
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L, 3L),
null,
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
Set.of(1L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(2L),
Set.of(3L),
Set.of(2L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(3L),
null,
null,
true,
1L
)
);
}
@ParameterizedTest
@MethodSource("validateCompleteDeviceListSource")
void testValidateCompleteDeviceList(
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds,
boolean isSyncMessage,
Long authenticatedDeviceId) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId)));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
MessageValidation.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId));
}
}
}