Convert Device.id from long to byte

This commit is contained in:
Chris Eager
2023-10-24 18:58:13 -05:00
committed by Chris Eager
parent 7299067829
commit 6a428b4da9
112 changed files with 1292 additions and 1094 deletions

View File

@@ -43,7 +43,7 @@ import java.util.Set;
import java.util.UUID;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
@@ -89,7 +89,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(1L);
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
@@ -126,7 +126,8 @@ class AuthEnablementRefreshRequirementProviderTest {
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId)));
IntStream.range(2, 4)
.forEach(deviceId -> account.addDevice(DevicesHelper.createDevice((byte) deviceId)));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
@@ -137,22 +138,22 @@ class AuthEnablementRefreshRequirementProviderTest {
@Test
void testBuildDevicesEnabled() {
final long disabledDeviceId = 3L;
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
LongStream.range(1, 5)
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(id);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Long, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
final Map<Byte, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
assertEquals(4, devicesEnabled.size());
@@ -168,7 +169,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map<Long, Boolean> initialEnabled, final Map<Long, Boolean> finalEnabled) {
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
assert initialEnabled.size() == finalEnabled.size();
assert account.getPrimaryDevice().orElseThrow().isEnabled();
@@ -199,13 +200,16 @@ class AuthEnablementRefreshRequirementProviderTest {
}
static Stream<Arguments> testDeviceEnabledChanged() {
final byte deviceId1 = Device.PRIMARY_ID;
final byte deviceId2 = 2;
final byte deviceId3 = 3;
return Stream.of(
Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)),
Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, false, 3L, true), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, false), Map.of(2L, true, 3L, true))
Arguments.of(Map.of(deviceId1, false, deviceId2, false), Map.of(deviceId1, true, deviceId2, false)),
Arguments.of(Map.of(deviceId2, false, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, false, deviceId3, false)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, false, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, false), Map.of(deviceId2, true, deviceId3, true))
);
}
@@ -227,9 +231,9 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verify(clientPresenceManager).disconnectPresence(account.getUuid(), 1);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), 2);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), 3);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3);
}
@ParameterizedTest
@@ -237,13 +241,13 @@ class AuthEnablementRefreshRequirementProviderTest {
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getPrimaryDevice().orElseThrow().isEnabled();
final List<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList());
final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();
final List<Long> deletedDeviceIds = account.getDevices().stream()
final List<Byte> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != 1L)
.filter(deviceId -> deviceId != Device.PRIMARY_ID)
.limit(removedDeviceCount)
.collect(Collectors.toList());
.toList();
assert deletedDeviceIds.size() == removedDeviceCount;
@@ -269,9 +273,9 @@ class AuthEnablementRefreshRequirementProviderTest {
void testPrimaryDeviceDisabledAndDeviceRemoved() {
assert account.getPrimaryDevice().orElseThrow().isEnabled();
final Set<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final Set<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final long deletedDeviceId = 2L;
final byte deletedDeviceId = 2;
assertTrue(initialDeviceIds.remove(deletedDeviceId));
final Response response = resources.getJerseyTest()
@@ -427,11 +431,11 @@ class AuthEnablementRefreshRequirementProviderTest {
@POST
@Path("/account/devices/enabled")
@ChangesDeviceEnabledState
public String setEnabled(@Auth TestPrincipal principal, Map<Long, Boolean> deviceIdsEnabled) {
public String setEnabled(@Auth TestPrincipal principal, Map<Byte, Boolean> deviceIdsEnabled) {
final StringBuilder response = new StringBuilder();
for (Entry<Long, Boolean> deviceIdEnabled : deviceIdsEnabled.entrySet()) {
for (Entry<Byte, Boolean> deviceIdEnabled : deviceIdsEnabled.entrySet()) {
final Device device = principal.getAccount().getDevice(deviceIdEnabled.getKey()).orElseThrow();
DevicesHelper.setEnabled(device, deviceIdEnabled.getValue());
@@ -462,7 +466,7 @@ class AuthEnablementRefreshRequirementProviderTest {
public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
Arrays.stream(deviceIds.split(","))
.map(Long::valueOf)
.map(Byte::valueOf)
.forEach(auth.getAccount()::removeDevice);
return "Removed device(s) " + deviceIds;
@@ -471,7 +475,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@POST
@Path("/account/disablePrimaryDeviceAndDeleteDevice/{deviceId}")
@ChangesDeviceEnabledState
public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") long deviceId) {
public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") byte deviceId) {
DevicesHelper.setEnabled(auth.getAccount().getPrimaryDevice().orElseThrow(), false);

View File

@@ -150,7 +150,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticate() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@@ -180,7 +180,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticateNonDefaultDevice() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 2;
final byte deviceId = 2;
final String password = "12345";
final Account account = mock(Account.class);
@@ -214,7 +214,7 @@ class BaseAccountAuthenticatorTest {
@CartesianTest.Values(booleans = {true, false}) final boolean deviceEnabled,
@CartesianTest.Values(booleans = {true, false}) final boolean authenticatedDeviceIsPrimary) {
final UUID uuid = UUID.randomUUID();
final long deviceId = authenticatedDeviceIsPrimary ? 1 : 2;
final byte deviceId = (byte) (authenticatedDeviceIsPrimary ? 1 : 2);
final String password = "12345";
final Account account = mock(Account.class);
@@ -253,7 +253,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticateV1() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@@ -290,7 +290,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticateDeviceNotFound() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@@ -312,13 +312,13 @@ class BaseAccountAuthenticatorTest {
baseAccountAuthenticator.authenticate(new BasicCredentials(uuid + "." + (deviceId + 1), password), true);
assertThat(maybeAuthenticatedAccount).isEmpty();
verify(account).getDevice(deviceId + 1);
verify(account).getDevice((byte) (deviceId + 1));
}
@Test
void testAuthenticateIncorrectPassword() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@@ -365,8 +365,9 @@ class BaseAccountAuthenticatorTest {
@ParameterizedTest
@MethodSource
void testGetIdentifierAndDeviceId(final String username, final String expectedIdentifier, final long expectedDeviceId) {
final Pair<String, Long> identifierAndDeviceId = BaseAccountAuthenticator.getIdentifierAndDeviceId(username);
void testGetIdentifierAndDeviceId(final String username, final String expectedIdentifier,
final byte expectedDeviceId) {
final Pair<String, Byte> identifierAndDeviceId = BaseAccountAuthenticator.getIdentifierAndDeviceId(username);
assertEquals(expectedIdentifier, identifierAndDeviceId.first());
assertEquals(expectedDeviceId, identifierAndDeviceId.second());
@@ -376,7 +377,7 @@ class BaseAccountAuthenticatorTest {
return Stream.of(
Arguments.of("", "", Device.PRIMARY_ID),
Arguments.of("test", "test", Device.PRIMARY_ID),
Arguments.of("test.7", "test", 7));
Arguments.of("test.7", "test", (byte) 7));
}
@ParameterizedTest

View File

@@ -34,11 +34,11 @@ class CertificateGeneratorTest {
final CertificateGenerator certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(SIGNING_CERTIFICATE), Curve.decodePrivatePoint(Base64.getDecoder().decode(SIGNING_KEY)), 1);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn("+18005551234");
when(device.getId()).thenReturn(4L);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn("+18005551234");
when(device.getId()).thenReturn((byte) 4);
assertTrue(certificateGenerator.createFor(account, device, true).length > 0);
assertTrue(certificateGenerator.createFor(account, device, false).length > 0);
assertTrue(certificateGenerator.createFor(account, device, true).length > 0);
assertTrue(certificateGenerator.createFor(account, device, false).length > 0);
}
}

View File

@@ -32,7 +32,7 @@ class OptionalAccessTest {
void testUnidentifiedMissingTargetDevice() {
Account account = mock(Account.class);
when(account.isEnabled()).thenReturn(true);
when(account.getDevice(eq(10))).thenReturn(Optional.empty());
when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty());
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes()));
try {
@@ -46,7 +46,7 @@ class OptionalAccessTest {
void testUnidentifiedBadTargetDevice() {
Account account = mock(Account.class);
when(account.isEnabled()).thenReturn(true);
when(account.getDevice(eq(10))).thenReturn(Optional.empty());
when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty());
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes()));
try {

View File

@@ -18,9 +18,9 @@ import org.whispersystems.textsecuregcm.util.Pair;
public class MockAuthenticationInterceptor implements ServerInterceptor {
@Nullable
private Pair<UUID, Long> authenticatedDevice;
private Pair<UUID, Byte> authenticatedDevice;
public void setAuthenticatedDevice(final UUID accountIdentifier, final long deviceId) {
public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) {
authenticatedDevice = new Pair<>(accountIdentifier, deviceId);
}

View File

@@ -10,8 +10,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
@@ -299,7 +299,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000"));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any());
}
@Test
@@ -328,7 +328,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("second"));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any());
}
@Test
@@ -344,7 +344,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(null);
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any());
}
@ParameterizedTest

View File

@@ -160,7 +160,7 @@ class AccountControllerV2Test {
}
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@@ -481,7 +481,7 @@ class AccountControllerV2Test {
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni);
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@@ -661,7 +661,7 @@ class AccountControllerV2Test {
assertEquals(account.isUnrestrictedUnidentifiedAccess(),
structuredResponse.data().account().allowSealedSenderFromAnyone());
final Set<Long> deviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final Set<Byte> deviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
// all devices should be present
structuredResponse.data().devices().forEach(deviceDataReport -> {
@@ -704,8 +704,8 @@ class AccountControllerV2Test {
buildTestAccountForDataReport(UUID.randomUUID(), exampleNumber1,
true, true,
Collections.emptyList(),
List.of(new DeviceData(1, account1Device1LastSeen, account1Device1Created, null),
new DeviceData(2, account1Device2LastSeen, account1Device2Created, "OWP"))),
List.of(new DeviceData(Device.PRIMARY_ID, account1Device1LastSeen, account1Device1Created, null),
new DeviceData((byte) 2, account1Device2LastSeen, account1Device2Created, "OWP"))),
String.format("""
# Account
Phone number: %s
@@ -730,7 +730,7 @@ class AccountControllerV2Test {
buildTestAccountForDataReport(UUID.randomUUID(), account2PhoneNumber,
false, true,
List.of(new AccountBadge("badge_a", badgeAExpiration, true)),
List.of(new DeviceData(1, account2Device1LastSeen, account2Device1Created, "OWI"))),
List.of(new DeviceData(Device.PRIMARY_ID, account2Device1LastSeen, account2Device1Created, "OWI"))),
String.format("""
# Account
Phone number: %s
@@ -756,7 +756,7 @@ class AccountControllerV2Test {
List.of(
new AccountBadge("badge_b", badgeBExpiration, true),
new AccountBadge("badge_c", badgeCExpiration, false)),
List.of(new DeviceData(1, account3Device1LastSeen, account3Device1Created, "OWA"))),
List.of(new DeviceData(Device.PRIMARY_ID, account3Device1LastSeen, account3Device1Created, "OWA"))),
String.format("""
# Account
Phone number: %s
@@ -825,7 +825,7 @@ class AccountControllerV2Test {
return account;
}
private record DeviceData(long id, Instant lastSeen, Instant created, @Nullable String userAgent) {
private record DeviceData(byte id, Instant lastSeen, Instant created, @Nullable String userAgent) {
}

View File

@@ -8,7 +8,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.eq;
@@ -99,6 +99,8 @@ class DeviceControllerTest {
private static Map<String, Integer> deviceConfiguration = new HashMap<>();
private static TestClock testClock = TestClock.now();
private static final byte NEXT_DEVICE_ID = 42;
private static DeviceController deviceController = new DeviceController(
generateLinkDeviceSecret(),
accountsManager,
@@ -137,9 +139,9 @@ class DeviceControllerTest {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(primaryDevice.getId()).thenReturn(1L);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(account.getNextDeviceId()).thenReturn(42L);
when(account.getNextDeviceId()).thenReturn(NEXT_DEVICE_ID);
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
@@ -154,9 +156,9 @@ class DeviceControllerTest {
AccountsHelper.setupMockUpdate(accountsManager);
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
}
@AfterEach
@@ -199,9 +201,9 @@ class DeviceControllerTest {
MediaType.APPLICATION_JSON_TYPE),
DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L);
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(commands).set(anyString(), anyString(), any());
}
@@ -315,7 +317,7 @@ class DeviceControllerTest {
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L);
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<Device> deviceCaptor = ArgumentCaptor.forClass(Device.class);
verify(account).addDevice(deviceCaptor.capture());
@@ -335,7 +337,7 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
@@ -751,7 +753,7 @@ class DeviceControllerTest {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
final long deviceId = 2;
final byte deviceId = 2;
final Response response = resources
.getJerseyTest()
@@ -785,10 +787,10 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(403);
verify(messagesManager, never()).clear(any(), anyLong());
verify(messagesManager, never()).clear(any(), anyByte());
verify(accountsManager, never()).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT, never()).removeDevice(anyLong());
verify(keysManager, never()).delete(any(), anyLong());
verify(AuthHelper.VALID_ACCOUNT, never()).removeDevice(anyByte());
verify(keysManager, never()).delete(any(), anyByte());
}
}

View File

@@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.clearInvocations;
@@ -84,6 +84,11 @@ class KeysControllerTest {
private static final UUID NOT_EXISTS_UUID = UUID.randomUUID();
private static final byte SAMPLE_DEVICE_ID = 1;
private static final byte SAMPLE_DEVICE_ID2 = 2;
private static final byte SAMPLE_DEVICE_ID3 = 3;
private static final byte SAMPLE_DEVICE_ID4 = 4;
private static final int SAMPLE_REGISTRATION_ID = 999;
private static final int SAMPLE_REGISTRATION_ID2 = 1002;
private static final int SAMPLE_REGISTRATION_ID4 = 1555;
@@ -180,6 +185,11 @@ class KeysControllerTest {
final List<Device> allDevices = List.of(sampleDevice, sampleDevice2, sampleDevice3, sampleDevice4);
final byte sampleDeviceId = 1;
final byte sampleDevice2Id = 2;
final byte sampleDevice3Id = 3;
final byte sampleDevice4Id = 4;
AccountsHelper.setupMockUpdate(accounts);
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
@@ -199,18 +209,18 @@ class KeysControllerTest {
when(sampleDevice2.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY2);
when(sampleDevice3.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY3);
when(sampleDevice4.getSignedPreKey(IdentityType.PNI)).thenReturn(null);
when(sampleDevice.getId()).thenReturn(1L);
when(sampleDevice2.getId()).thenReturn(2L);
when(sampleDevice3.getId()).thenReturn(3L);
when(sampleDevice4.getId()).thenReturn(4L);
when(sampleDevice.getId()).thenReturn(sampleDeviceId);
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);
when(sampleDevice4.getId()).thenReturn(sampleDevice4Id);
when(existsAccount.getUuid()).thenReturn(EXISTS_UUID);
when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI);
when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice(22L)).thenReturn(Optional.empty());
when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(sampleDevice4Id)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice((byte) 22)).thenReturn(Optional.empty());
when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isEnabled()).thenReturn(true);
when(existsAccount.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
@@ -225,17 +235,21 @@ class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.store(any(), anyLong(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.store(any(), anyByte(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI)));
when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI)));
when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takeEC(EXISTS_PNI, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI)));
when(KEYS.takePQ(EXISTS_PNI, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI)));
when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5));
when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.ACI)).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.PNI)).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY);
@@ -267,8 +281,8 @@ class KeysControllerTest {
assertThat(result.getCount()).isEqualTo(5);
assertThat(result.getPqCount()).isEqualTo(5);
verify(KEYS).getEcCount(AuthHelper.VALID_UUID, 1);
verify(KEYS).getPqCount(AuthHelper.VALID_UUID, 1);
verify(KEYS).getEcCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getPqCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID);
}
@Test
@@ -284,7 +298,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(AuthHelper.VALID_DEVICE, never()).setPhoneNumberIdentitySignedPreKey(any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any());
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(Device.PRIMARY_ID, test));
}
@@ -303,7 +317,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(replacementKey));
verify(AuthHelper.VALID_DEVICE, never()).setSignedPreKey(any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any());
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(Device.PRIMARY_ID, replacementKey));
}
@@ -329,20 +343,20 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
@@ -353,15 +367,15 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@@ -376,15 +390,15 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@@ -398,14 +412,14 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@@ -420,15 +434,15 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).takePQ(EXISTS_PNI, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@@ -444,14 +458,14 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@@ -481,14 +495,14 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey());
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@@ -534,10 +548,14 @@ class KeysControllerTest {
@Test
void validMultiRequestTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2)));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@@ -548,56 +566,62 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
ECPreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey();
ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey();
long registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId();
byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId();
assertEquals(SAMPLE_KEY, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId();
assertEquals(SAMPLE_KEY2, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId();
assertEquals(SAMPLE_KEY4, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).takeEC(EXISTS_UUID, 4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verifyNoMoreInteractions(KEYS);
}
@Test
void validMultiRequestPqTestV2() {
when(KEYS.takeEC(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takePQ(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takePQ(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@@ -609,51 +633,51 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
ECPreKey preKey = results.getDevice(1).getPreKey();
KEMSignedPreKey pqPreKey = results.getDevice(1).getPqPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey();
ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey();
KEMSignedPreKey pqPreKey = results.getDevice(SAMPLE_DEVICE_ID).getPqPreKey();
int registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId();
byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId();
assertEquals(SAMPLE_KEY, preKey);
assertEquals(SAMPLE_PQ_KEY, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
pqPreKey = results.getDevice(2).getPqPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey();
pqPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getPqPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId();
assertThat(preKey).isNull();
assertEquals(SAMPLE_PQ_KEY2, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
pqPreKey = results.getDevice(4).getPqPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey();
pqPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getPqPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId();
assertEquals(SAMPLE_KEY4, preKey);
assertThat(pqPreKey).isNull();
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).takePQ(EXISTS_UUID, 2);
verify(KEYS).takeEC(EXISTS_UUID, 4);
verify(KEYS).takePQ(EXISTS_UUID, 4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verifyNoMoreInteractions(KEYS);
}
@@ -719,7 +743,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull());
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(),
eq(signedPreKey), isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey);
@@ -750,7 +775,8 @@ class KeysControllerTest {
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(signedPreKey), eq(pqLastResortPreKey));
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(),
eq(signedPreKey), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
@@ -852,7 +878,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull());
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), eq(signedPreKey),
isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey);
@@ -884,7 +911,8 @@ class KeysControllerTest {
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(signedPreKey), eq(pqLastResortPreKey));
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(),
eq(signedPreKey), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
@@ -928,7 +956,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull());
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(),
eq(signedPreKey), isNull());
List<ECPreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@@ -953,7 +982,8 @@ class KeysControllerTest {
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, 2L, AuthHelper.VALID_PASSWORD_3_LINKED))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, SAMPLE_DEVICE_ID2,
AuthHelper.VALID_PASSWORD_3_LINKED))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);

View File

@@ -135,15 +135,15 @@ class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.fromString("11111111-1111-1111-1111-111111111111");
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
private static final int SINGLE_DEVICE_ID1 = 1;
private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
private static final UUID MULTI_DEVICE_PNI = UUID.fromString("22222222-0000-0000-0000-222222222222");
private static final int MULTI_DEVICE_ID1 = 1;
private static final int MULTI_DEVICE_ID2 = 2;
private static final int MULTI_DEVICE_ID3 = 3;
private static final byte MULTI_DEVICE_ID1 = 1;
private static final byte MULTI_DEVICE_ID2 = 2;
private static final byte MULTI_DEVICE_ID3 = 3;
private static final int MULTI_DEVICE_REG_ID1 = 222;
private static final int MULTI_DEVICE_REG_ID2 = 333;
private static final int MULTI_DEVICE_REG_ID3 = 444;
@@ -225,7 +225,8 @@ class MessageControllerTest {
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
}
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId,
final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
final Device device = new Device();
device.setId(id);
device.setRegistrationId(registrationId);
@@ -526,13 +527,14 @@ class MessageControllerTest {
final UUID updatedPniOne = UUID.randomUUID();
List<Envelope> envelopes = List.of(
generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2,
generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, (byte) 2,
AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0, false),
generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2,
generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid,
(byte) 2,
AuthHelper.VALID_UUID, null, null, 0, true)
);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean()))
.thenReturn(Mono.just(new Pair<>(envelopes, false)));
final String userAgent = "Test-UA";
@@ -580,13 +582,13 @@ class MessageControllerTest {
final long timestampTwo = 313388;
final List<Envelope> messages = List.of(
generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, UUID.randomUUID(), 2,
generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, UUID.randomUUID(), (byte) 2,
AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0),
generateEnvelope(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo,
UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)
UUID.randomUUID(), (byte) 2, AuthHelper.VALID_UUID, null, null, 0)
);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean()))
.thenReturn(Mono.just(new Pair<>(messages, false)));
Response response =
@@ -606,24 +608,24 @@ class MessageControllerTest {
UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid1, null))
.thenReturn(
CompletableFuture.completedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))));
timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))));
UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid2, null))
.thenReturn(
CompletableFuture.completedFuture(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0))));
System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0))));
UUID uuid3 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid3, null))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
UUID uuid4 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid4, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid4, null))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("Oh No")));
Response response = resources.getJerseyTest()
@@ -633,7 +635,7 @@ class MessageControllerTest {
.delete();
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq(1L),
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1),
eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp));
response = resources.getJerseyTest()
@@ -879,7 +881,7 @@ class MessageControllerTest {
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true,
List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true,
System.currentTimeMillis()),
MediaType.APPLICATION_JSON_TYPE));
@@ -919,7 +921,7 @@ class MessageControllerTest {
);
}
private static void writePayloadDeviceId(ByteBuffer bb, long deviceId) {
private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) {
long x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
@@ -1155,7 +1157,7 @@ class MessageControllerTest {
if (known) {
r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else {
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), 999, 999, new byte[48]);
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]);
}
Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
@@ -1250,7 +1252,7 @@ class MessageControllerTest {
SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountMismatchedDevices.class));
assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(Collections.emptyList(), List.of((long) MULTI_DEVICE_ID3)))),
new MismatchedDevices(Collections.emptyList(), List.of(MULTI_DEVICE_ID3)))),
mismatchedDevices);
}
@@ -1298,7 +1300,8 @@ class MessageControllerTest {
assertEquals(1, staleDevices.size());
assertEquals(serviceIdentifier, staleDevices.get(0).uuid());
assertEquals(Set.of((long) MULTI_DEVICE_ID1, (long) MULTI_DEVICE_ID2), new HashSet<>(staleDevices.get(0).devices().staleDevices()));
assertEquals(Set.of(MULTI_DEVICE_ID1, MULTI_DEVICE_ID2),
new HashSet<>(staleDevices.get(0).devices().staleDevices()));
}
private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
@@ -1380,12 +1383,12 @@ class MessageControllerTest {
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) {
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type))
@@ -1413,14 +1416,14 @@ class MessageControllerTest {
private static Recipient genRecipient(Random rng) {
UUID u1 = UUID.randomUUID(); // non-null
long d1 = rng.nextLong() & 0x3fffffffffffffffL + 1; // 1 to 4611686018427387903
byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
byte[] perKeyBytes = new byte[48]; // size=48, non-null
rng.nextBytes(perKeyBytes);
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
}
private static void roundTripVarint(long expected, byte [] bytes) throws Exception {
private static void roundTripVarint(byte expected, byte[] bytes) throws Exception {
ByteBuffer bb = ByteBuffer.wrap(bytes);
writePayloadDeviceId(bb, expected);
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
@@ -1434,15 +1437,17 @@ class MessageControllerTest {
byte[] bytes = new byte[12];
// some static test cases
for (long i = 1L; i <= 10L; i++) {
for (byte i = 1; i <= 10; i++) {
roundTripVarint(i, bytes);
}
roundTripVarint(Long.MAX_VALUE, bytes);
roundTripVarint(Byte.MAX_VALUE, bytes);
for (int i = 0; i < 1000; i++) {
// we need to ensure positive device IDs
long start = rng.nextLong() & Long.MAX_VALUE;
if (start == 0L) start = 1L;
byte start = (byte) rng.nextInt(128);
if (start == 0L) {
start = 1;
}
// run the test for this case
roundTripVarint(start, bytes);

View File

@@ -75,12 +75,12 @@ class OutgoingMessageEntityTest {
final Account account = new Account();
account.setUuid(UUID.randomUUID());
IncomingMessage message = new IncomingMessage(1, 4444L, 55, "AAAAAA");
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, "AAAAAA");
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
new AciServiceIdentifier(UUID.randomUUID()),
account,
123L,
(byte) 123,
System.currentTimeMillis(),
false,
true,

View File

@@ -170,7 +170,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Test
void deleteAccountLinkedDevice() {
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,
@@ -215,7 +215,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Test
void setRegistrationLockLinkedDevice() {
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,
@@ -240,7 +240,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Test
void clearRegistrationLockLinkedDevice() {
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,

View File

@@ -7,7 +7,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
@@ -88,7 +88,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
return CompletableFuture.completedFuture(account);
});
when(accountsManager.updateDeviceAsync(any(), anyLong(), any()))
when(accountsManager.updateDeviceAsync(any(), anyByte(), any()))
.thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final Device device = account.getDevice(invocation.getArgument(1)).orElseThrow();
@@ -99,8 +99,8 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
return CompletableFuture.completedFuture(account);
});
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
return new DevicesGrpcService(accountsManager, keysManager, messagesManager);
}
@@ -120,7 +120,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final String linkedDeviceName = "A linked device";
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(Device.PRIMARY_ID + 1);
when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
when(linkedDevice.getCreated()).thenReturn(linkedDeviceCreated.toEpochMilli());
when(linkedDevice.getLastSeen()).thenReturn(linkedDeviceLastSeen.toEpochMilli());
when(linkedDevice.getName())
@@ -147,7 +147,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@Test
void removeDevice() {
final long deviceId = 17;
final byte deviceId = 17;
final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(deviceId)
@@ -167,15 +167,15 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@Test
void removeDeviceNonPrimaryAuthenticated() {
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(17)
.build()));
}
@ParameterizedTest
@ValueSource(longs = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1})
void setDeviceName(final long deviceId) {
@ValueSource(bytes = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1})
void setDeviceName(final byte deviceId) {
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
@@ -212,7 +212,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@ParameterizedTest
@MethodSource
void setPushToken(final long deviceId,
void setPushToken(final byte deviceId,
final SetPushTokenRequest request,
@Nullable final String expectedApnsToken,
@Nullable final String expectedApnsVoipToken,
@@ -238,7 +238,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final Stream.Builder<Arguments> streamBuilder = Stream.builder();
for (final long deviceId : new long[] { Device.PRIMARY_ID, Device.PRIMARY_ID + 1 }) {
for (final byte deviceId : new byte[]{Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) {
streamBuilder.add(Arguments.of(deviceId,
SetPushTokenRequest.newBuilder()
.setApnsTokenRequest(SetPushTokenRequest.ApnsTokenRequest.newBuilder()
@@ -284,7 +284,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
}
private static Stream<Arguments> setPushTokenUnchanged() {
@@ -323,7 +323,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setPushToken(request));
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
}
private static Stream<Arguments> setPushTokenIllegalArgument() {
@@ -342,7 +342,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@ParameterizedTest
@MethodSource
void clearPushToken(final long deviceId,
void clearPushToken(final byte deviceId,
@Nullable final String apnsToken,
@Nullable final String apnsVoipToken,
@Nullable final String fcmToken,
@@ -379,17 +379,17 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
Arguments.of(Device.PRIMARY_ID, null, "apns-voip-token", null, "OWI"),
Arguments.of(Device.PRIMARY_ID, null, null, "fcm-token", "OWA"),
Arguments.of(Device.PRIMARY_ID, null, null, null, null),
Arguments.of(Device.PRIMARY_ID + 1, "apns-token", null, null, "OWP"),
Arguments.of(Device.PRIMARY_ID + 1, "apns-token", "apns-voip-token", null, "OWP"),
Arguments.of(Device.PRIMARY_ID + 1, null, "apns-voip-token", null, "OWP"),
Arguments.of(Device.PRIMARY_ID + 1, null, null, "fcm-token", "OWA"),
Arguments.of(Device.PRIMARY_ID + 1, null, null, null, null)
Arguments.of((byte) (Device.PRIMARY_ID + 1), "apns-token", null, null, "OWP"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), "apns-token", "apns-voip-token", null, "OWP"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), null, "apns-voip-token", null, "OWP"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), null, null, "fcm-token", "OWA"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), null, null, null, null)
);
}
@CartesianTest
void setCapabilities(
@CartesianTest.Values(longs = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) final long deviceId,
@CartesianTest.Values(bytes = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) final byte deviceId,
@CartesianTest.Values(booleans = {true, false}) final boolean storage,
@CartesianTest.Values(booleans = {true, false}) final boolean transfer,
@CartesianTest.Values(booleans = {true, false}) final boolean pni,

View File

@@ -31,7 +31,7 @@ public final class GrpcTestUtils {
final MockAuthenticationInterceptor mockAuthenticationInterceptor,
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor,
final UUID authenticatedAci,
final long authenticatedDeviceId,
final byte authenticatedDeviceId,
final BindableService service) {
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
extension.getServiceRegistry()

View File

@@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
@@ -184,8 +184,8 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcS
}
@ParameterizedTest
@ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final long deviceId) {
@ValueSource(bytes = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final byte deviceId) {
final UUID accountIdentifier = UUID.randomUUID();
final byte[] unidentifiedAccessKey = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
@@ -195,7 +195,7 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcS
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))

View File

@@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
@@ -151,7 +151,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
preKeys.add(new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()));
}
when(keysManager.storeEcOneTimePreKeys(any(), anyLong(), any()))
when(keysManager.storeEcOneTimePreKeys(any(), anyByte(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
@@ -222,7 +222,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
preKeys.add(KeysHelper.signedKEMPreKey(keyId, identityKeyPair));
}
when(keysManager.storeKemOneTimePreKeys(any(), anyLong(), any()))
when(keysManager.storeKemOneTimePreKeys(any(), anyByte(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
@@ -294,9 +294,9 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
@ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void setSignedPreKey(final org.signal.chat.common.IdentityType identityType) {
when(accountsManager.updateDeviceAsync(any(), anyLong(), any())).thenAnswer(invocation -> {
when(accountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final long deviceId = invocation.getArgument(1);
final byte deviceId = invocation.getArgument(1);
final Consumer<Device> deviceUpdater = invocation.getArgument(2);
account.getDevice(deviceId).ifPresent(deviceUpdater);
@@ -477,13 +477,16 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
when(accountsManager.getByServiceIdentifierAsync(argThat(serviceIdentifier -> serviceIdentifier.uuid().equals(identifier))))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final Map<Long, ECPreKey> ecOneTimePreKeys = new HashMap<>();
final Map<Long, KEMSignedPreKey> kemPreKeys = new HashMap<>();
final Map<Long, ECSignedPreKey> ecSignedPreKeys = new HashMap<>();
final Map<Byte, ECPreKey> ecOneTimePreKeys = new HashMap<>();
final Map<Byte, KEMSignedPreKey> kemPreKeys = new HashMap<>();
final Map<Byte, ECSignedPreKey> ecSignedPreKeys = new HashMap<>();
final Map<Long, Device> devices = new HashMap<>();
final Map<Byte, Device> devices = new HashMap<>();
for (final long deviceId : List.of(1, 2)) {
final byte deviceId1 = 1;
final byte deviceId2 = 2;
for (final byte deviceId : List.of(deviceId1, deviceId2)) {
ecOneTimePreKeys.put(deviceId, new ECPreKey(1, Curve.generateKeyPair().getPublicKey()));
kemPreKeys.put(deviceId, KeysHelper.signedKEMPreKey(2, identityKeyPair));
ecSignedPreKeys.put(deviceId, KeysHelper.signedECPreKey(3, identityKeyPair));
@@ -518,18 +521,18 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
.setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
.putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
.setKeyId(ecSignedPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).signature()))
.build())
.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecOneTimePreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
.setKeyId(ecOneTimePreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(deviceId1).serializedPublicKey()))
.build())
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
.setKeyId(kemPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(deviceId1).signature()))
.build())
.build())
.build();
@@ -537,8 +540,8 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
assertEquals(expectedResponse, response);
}
when(keysManager.takeEC(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(keysManager.takeEC(identifier, deviceId2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(keysManager.takePQ(identifier, deviceId2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
{
final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
@@ -552,25 +555,25 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
.setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
.putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
.setKeyId(ecSignedPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).signature()))
.build())
.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecOneTimePreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
.setKeyId(ecOneTimePreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(deviceId1).serializedPublicKey()))
.build())
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
.setKeyId(kemPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(deviceId1).signature()))
.build())
.build())
.putPreKeys(2, GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKeys.get(2L).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(2L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(2L).signature()))
.setKeyId(ecSignedPreKeys.get(deviceId2).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(deviceId2).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(deviceId2).signature()))
.build())
.build())
.build();
@@ -593,15 +596,15 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
}
@ParameterizedTest
@ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final long deviceId) {
@ValueSource(bytes = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final byte deviceId) {
final UUID accountIdentifier = UUID.randomUUID();
final Account targetAccount = mock(Account.class);
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
@@ -621,7 +624,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));

View File

@@ -55,7 +55,7 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
protected static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
protected static final long AUTHENTICATED_DEVICE_ID = Device.PRIMARY_ID;
protected static final byte AUTHENTICATED_DEVICE_ID = Device.PRIMARY_ID;
private AutoCloseable mocksCloseable;

View File

@@ -54,7 +54,7 @@ class APNSenderTest {
apnsClient = mock(ApnsClient.class);
apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, BUNDLE_ID);
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(destinationAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(destinationDevice));
when(destinationDevice.getApnId()).thenReturn(DESTINATION_DEVICE_TOKEN);
}

View File

@@ -30,7 +30,6 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -54,7 +53,7 @@ class ApnPushNotificationSchedulerTest {
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String ACCOUNT_NUMBER = "+18005551234";
private static final long DEVICE_ID = 1L;
private static final byte DEVICE_ID = 1;
private static final String APN_ID = RandomStringUtils.randomAlphanumeric(32);
private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32);
@@ -98,12 +97,12 @@ class ApnPushNotificationSchedulerTest {
final List<String> pendingDestinations = apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 2);
assertEquals(1, pendingDestinations.size());
final Optional<Pair<String, Long>> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated(
final Optional<Pair<String, Byte>> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated(
pendingDestinations.get(0));
assertTrue(maybeUuidAndDeviceId.isPresent());
assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first());
assertEquals(DEVICE_ID, (long) maybeUuidAndDeviceId.get().second());
assertEquals(DEVICE_ID, maybeUuidAndDeviceId.get().second());
assertTrue(
apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
@@ -236,8 +235,6 @@ class ApnPushNotificationSchedulerTest {
final AccountsManager accountsManager = mock(AccountsManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
apnPushNotificationScheduler = new ApnPushNotificationScheduler(redisCluster, apnSender,
accountsManager, dedicatedThreadCount);

View File

@@ -76,7 +76,7 @@ class ClientPresenceManagerTest {
@Test
void testIsPresent() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId));
@@ -87,7 +87,7 @@ class ClientPresenceManagerTest {
@Test
void testIsLocallyPresent() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertFalse(clientPresenceManager.isLocallyPresent(accountUuid, deviceId));
@@ -100,7 +100,7 @@ class ClientPresenceManagerTest {
@Test
void testLocalDisplacement() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final AtomicInteger displacementCounter = new AtomicInteger(0);
final DisplacedPresenceListener displacementListener = connectedElsewhere -> displacementCounter.incrementAndGet();
@@ -117,7 +117,7 @@ class ClientPresenceManagerTest {
@Test
void testRemoteDisplacement() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
@@ -135,7 +135,7 @@ class ClientPresenceManagerTest {
@Test
void testRemoteDisplacementAfterTopologyChange() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
@@ -157,7 +157,7 @@ class ClientPresenceManagerTest {
@Test
void testClearPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId));
@@ -210,7 +210,7 @@ class ClientPresenceManagerTest {
@Test
void testInitialPresenceExpiration() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
@@ -225,7 +225,7 @@ class ClientPresenceManagerTest {
@Test
void testRenewPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String presenceKey = ClientPresenceManager.getPresenceKey(accountUuid, deviceId);
@@ -252,7 +252,7 @@ class ClientPresenceManagerTest {
@Test
void testExpiredPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
@@ -266,7 +266,7 @@ class ClientPresenceManagerTest {
}
private void addClientPresence(final String managerId) {
final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), 7);
final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), (byte) 7);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
connection.sync().set(clientPresenceKey, managerId);
@@ -278,17 +278,17 @@ class ClientPresenceManagerTest {
void testClearAllOnStop() {
final int localAccounts = 10;
final UUID[] localUuids = new UUID[localAccounts];
final long[] localDeviceIds = new long[localAccounts];
final byte[] localDeviceIds = new byte[localAccounts];
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
localDeviceIds[i] = (byte) i;
clientPresenceManager.setPresent(localUuids[i], localDeviceIds[i], NO_OP);
}
final UUID displacedAccountUuid = UUID.randomUUID();
final long displacedAccountDeviceId = 7;
final byte displacedAccountDeviceId = 7;
clientPresenceManager.setPresent(displacedAccountUuid, displacedAccountDeviceId, NO_OP);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> connection.sync()
@@ -299,7 +299,7 @@ class ClientPresenceManagerTest {
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
localDeviceIds[i] = (byte) i;
assertFalse(clientPresenceManager.isPresent(localUuids[i], localDeviceIds[i]));
}
@@ -346,7 +346,7 @@ class ClientPresenceManagerTest {
@Test
void testSetPresentRemotely() {
final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1L;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null);
@@ -360,7 +360,7 @@ class ClientPresenceManagerTest {
@Test
void testDisconnectPresenceLocally() {
final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1L;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null);
@@ -374,7 +374,7 @@ class ClientPresenceManagerTest {
@Test
void testDisconnectPresenceRemotely() {
final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1L;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null);

View File

@@ -10,7 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
@@ -42,7 +42,7 @@ class MessageSenderTest {
private MessageSender messageSender;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
private static final byte DEVICE_ID = 1;
@BeforeEach
void setUp() {
@@ -73,7 +73,7 @@ class MessageSenderTest {
ArgumentCaptor<MessageProtos.Envelope> envelopeArgumentCaptor = ArgumentCaptor.forClass(
MessageProtos.Envelope.class);
verify(messagesManager).insert(any(), anyLong(), envelopeArgumentCaptor.capture());
verify(messagesManager).insert(any(), anyByte(), envelopeArgumentCaptor.capture());
assertTrue(envelopeArgumentCaptor.getValue().getEphemeral());
@@ -87,7 +87,7 @@ class MessageSenderTest {
messageSender.sendMessage(account, device, message, true);
verify(messagesManager, never()).insert(any(), anyLong(), any());
verify(messagesManager, never()).insert(any(), anyByte(), any());
verifyNoInteractions(pushNotificationManager);
}

View File

@@ -1,6 +1,16 @@
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.after;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import com.google.protobuf.ByteString;
import java.time.Duration;
import java.util.Random;
import java.util.function.Consumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -11,17 +21,6 @@ import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import java.time.Duration;
import java.util.Random;
import java.util.function.Consumer;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.after;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
class ProvisioningManagerTest {
private ProvisioningManager provisioningManager;
@@ -44,7 +43,7 @@ class ProvisioningManagerTest {
@Test
void sendProvisioningMessage() {
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
final ProvisioningAddress address = new ProvisioningAddress("address", (byte) 0);
final byte[] content = new byte[16];
new Random().nextBytes(content);
@@ -65,7 +64,7 @@ class ProvisioningManagerTest {
@Test
void removeListener() {
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
final ProvisioningAddress address = new ProvisioningAddress("address", (byte) 0);
final byte[] content = new byte[16];
new Random().nextBytes(content);

View File

@@ -35,7 +35,7 @@ class PushLatencyManagerTest {
@MethodSource
void testTakeRecord(final boolean isVoip, final boolean isUrgent) throws ExecutionException, InterruptedException {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final Instant pushTimestamp = Instant.now();

View File

@@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -85,15 +86,16 @@ class AccountTest {
when(agingSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31));
when(agingSecondaryDevice.isEnabled()).thenReturn(false);
when(agingSecondaryDevice.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(agingSecondaryDevice.getId()).thenReturn(deviceId2);
when(recentSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(1));
when(recentSecondaryDevice.isEnabled()).thenReturn(true);
when(recentSecondaryDevice.getId()).thenReturn(2L);
when(recentSecondaryDevice.getId()).thenReturn(deviceId2);
when(oldSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366));
when(oldSecondaryDevice.isEnabled()).thenReturn(false);
when(oldSecondaryDevice.getId()).thenReturn(2L);
when(oldSecondaryDevice.getId()).thenReturn(deviceId2);
when(senderKeyCapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
@@ -143,17 +145,17 @@ class AccountTest {
new DeviceCapabilities(true, true, false, false));
when(pniIncapableExpiredDevice.isEnabled()).thenReturn(false);
when(storiesCapableDevice.getId()).thenReturn(1L);
when(storiesCapableDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(storiesCapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
when(storiesCapableDevice.isEnabled()).thenReturn(true);
when(storiesCapableDevice.getId()).thenReturn(2L);
when(storiesCapableDevice.getId()).thenReturn(deviceId2);
when(storiesIncapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
when(storiesIncapableDevice.isEnabled()).thenReturn(true);
when(storiesCapableDevice.getId()).thenReturn(3L);
when(storiesCapableDevice.getId()).thenReturn((byte) 3);
when(storiesIncapableExpiredDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
when(storiesIncapableExpiredDevice.isEnabled()).thenReturn(false);
@@ -192,10 +194,11 @@ class AccountTest {
when(disabledPrimaryDevice.isEnabled()).thenReturn(false);
when(disabledLinkedDevice.isEnabled()).thenReturn(false);
when(enabledPrimaryDevice.getId()).thenReturn(1L);
when(enabledLinkedDevice.getId()).thenReturn(2L);
when(disabledPrimaryDevice.getId()).thenReturn(1L);
when(disabledLinkedDevice.getId()).thenReturn(2L);
when(enabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
final byte deviceId2 = 2;
when(enabledLinkedDevice.getId()).thenReturn(deviceId2);
when(disabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(disabledLinkedDevice.getId()).thenReturn(deviceId2);
assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice)).isEnabled());
assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice, enabledLinkedDevice)).isEnabled());
@@ -214,15 +217,15 @@ class AccountTest {
final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class);
when(transferCapablePrimaryDevice.getId()).thenReturn(1L);
when(transferCapablePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(transferCapablePrimaryDevice.isPrimary()).thenReturn(true);
when(transferCapablePrimaryDevice.getCapabilities()).thenReturn(transferCapabilities);
when(nonTransferCapablePrimaryDevice.getId()).thenReturn(1L);
when(nonTransferCapablePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(nonTransferCapablePrimaryDevice.isPrimary()).thenReturn(true);
when(nonTransferCapablePrimaryDevice.getCapabilities()).thenReturn(nonTransferCapabilities);
when(transferCapableLinkedDevice.getId()).thenReturn(2L);
when(transferCapableLinkedDevice.getId()).thenReturn((byte) 2);
when(transferCapableLinkedDevice.isPrimary()).thenReturn(false);
when(transferCapableLinkedDevice.getCapabilities()).thenReturn(transferCapabilities);
@@ -311,21 +314,31 @@ class AccountTest {
final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), devices, new byte[0]);
assertThat(account.getNextDeviceId()).isEqualTo(2L);
final byte deviceId2 = 2;
assertThat(account.getNextDeviceId()).isEqualTo(deviceId2);
account.addDevice(createDevice(2L));
account.addDevice(createDevice(deviceId2));
assertThat(account.getNextDeviceId()).isEqualTo(3L);
final byte deviceId3 = 3;
assertThat(account.getNextDeviceId()).isEqualTo(deviceId3);
account.addDevice(createDevice(3L));
account.addDevice(createDevice(deviceId3));
setEnabled(account.getDevice(2L).orElseThrow(), false);
setEnabled(account.getDevice(deviceId2).orElseThrow(), false);
assertThat(account.getNextDeviceId()).isEqualTo(4L);
assertThat(account.getNextDeviceId()).isEqualTo((byte) 4);
account.removeDevice(2L);
account.removeDevice(deviceId2);
assertThat(account.getNextDeviceId()).isEqualTo(2L);
assertThat(account.getNextDeviceId()).isEqualTo(deviceId2);
while (account.getNextDeviceId() < Device.MAXIMUM_DEVICE_ID) {
account.addDevice(createDevice(account.getNextDeviceId()));
}
account.addDevice(createDevice(Device.MAXIMUM_DEVICE_ID));
assertThatThrownBy(account::getNextDeviceId).isInstanceOf(RuntimeException.class);
}
@Test
@@ -399,7 +412,7 @@ class AccountTest {
final Device disabledPrimary = mock(Device.class);
when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
final long linked1DeviceId = Device.PRIMARY_ID + 1;
final byte linked1DeviceId = Device.PRIMARY_ID + 1;
final Device enabledLinked1 = mock(Device.class);
when(enabledLinked1.isEnabled()).thenReturn(true);
when(enabledLinked1.getId()).thenReturn(linked1DeviceId);
@@ -407,7 +420,7 @@ class AccountTest {
final Device disabledLinked1 = mock(Device.class);
when(disabledLinked1.getId()).thenReturn(linked1DeviceId);
final long linked2DeviceId = Device.PRIMARY_ID + 2;
final byte linked2DeviceId = Device.PRIMARY_ID + 2;
final Device enabledLinked2 = mock(Device.class);
when(enabledLinked2.isEnabled()).thenReturn(true);
when(enabledLinked2.getId()).thenReturn(linked2DeviceId);

View File

@@ -178,8 +178,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalPni = account.getPhoneNumberIdentifier();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey);
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds);

View File

@@ -141,8 +141,8 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accountsManager.create("+14155551212", "password", null, new AccountAttributes(), new ArrayList<>()),
a -> {
a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
a.removeDevice(1);
a.addDevice(DevicesHelper.createDevice(1));
a.removeDevice(Device.PRIMARY_ID);
a.addDevice(DevicesHelper.createDevice(Device.PRIMARY_ID));
});
uuid = account.getUuid();
@@ -212,7 +212,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
}, mutationExecutor);
}
private CompletableFuture<?> modifyDevice(final UUID uuid, final long deviceId, final Consumer<Device> deviceMutation) {
private CompletableFuture<?> modifyDevice(final UUID uuid, final byte deviceId, final Consumer<Device> deviceMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.getByAccountIdentifier(uuid).orElseThrow();

View File

@@ -876,7 +876,7 @@ class AccountsManagerTest {
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
final byte deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@@ -909,7 +909,7 @@ class AccountsManagerTest {
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
final byte deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@@ -1064,7 +1064,8 @@ class AccountsManagerTest {
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber(
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of(1L, 101)),
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of((byte) 1, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any());
@@ -1107,24 +1108,26 @@ class AccountsManagerTest {
final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final byte deviceId2 = 2;
final byte deviceId3 = 3;
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L, 3L)));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID, deviceId3)));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final List<Device> devices = List.of(
DevicesHelper.createDevice(1L, 0L, 101),
DevicesHelper.createDevice(2L, 0L, 102),
DevicesHelper.createDisabledDevice(3L, 103));
DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102),
DevicesHelper.createDisabledDevice(deviceId3, 103));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final Account updatedAccount = accountsManager.changeNumber(
account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds);
@@ -1140,7 +1143,8 @@ class AccountsManagerTest {
verify(keysManager).delete(originalPni);
verify(keysManager).getPqEnabledDevices(uuid);
verify(keysManager).storeEcSignedPreKeys(newPni, newSignedKeys);
verify(keysManager).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verify(keysManager).storePqLastResort(eq(newPni),
eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID))));
verifyNoMoreInteractions(keysManager);
}
@@ -1153,19 +1157,22 @@ class AccountsManagerTest {
final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final byte deviceId2 = 2;
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair));
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(
CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID)));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
assertThrows(MismatchedDevicesException.class,
() -> accountsManager.changeNumber(
@@ -1189,18 +1196,20 @@ class AccountsManagerTest {
@Test
void testPniUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@@ -1217,7 +1226,7 @@ class AccountsManagerTest {
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(1L, 101, 2L, 102),
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI stuff should
@@ -1236,26 +1245,29 @@ class AccountsManagerTest {
@Test
void testPniPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(
CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID)));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@@ -1270,7 +1282,7 @@ class AccountsManagerTest {
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(1L, 101, 2L, 102),
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
@@ -1287,23 +1299,26 @@ class AccountsManagerTest {
verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys);
// only the pq key for the already-pq-enabled device should be saved
verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verify(keysManager).storePqLastResort(eq(oldPni),
eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID))));
}
@Test
void testPniNonPqToPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
@@ -1312,7 +1327,7 @@ class AccountsManagerTest {
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@@ -1327,7 +1342,7 @@ class AccountsManagerTest {
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(1L, 101, 2L, 102),
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
@@ -1348,19 +1363,21 @@ class AccountsManagerTest {
@Test
void testPniUpdate_incompleteKeys() {
final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final byte deviceId2 = 2;
final byte deviceId3 = 3;
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
2L, KeysHelper.signedECPreKey(1, identityKeyPair),
3L, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
deviceId2, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId3, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@@ -1375,21 +1392,22 @@ class AccountsManagerTest {
@Test
void testPniPqUpdate_incompleteKeys() {
final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());

View File

@@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -75,6 +76,9 @@ import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class AccountsTest {
private static final byte DEVICE_ID_1 = 1;
private static final byte DEVICE_ID_2 = 2;
private static final String BASE_64_URL_USERNAME_HASH_1 = "9p6Tip7BFefFOJzv4kv4GyXEYsBVfk_WbjNejdlOvQE";
private static final String BASE_64_URL_USERNAME_HASH_2 = "NLUom-CHwtemcdvOTTXdmXmzRIV7F05leS8lwkVK_vc";
private static final String BASE_64_URL_ENCRYPTED_USERNAME_1 = "md1votbj9r794DsqTNrBqA";
@@ -156,7 +160,7 @@ class AccountsTest {
@Test
void testStore() {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account);
@@ -179,7 +183,7 @@ class AccountsTest {
void testStoreRecentlyDeleted() {
final UUID originalUuid = UUID.randomUUID();
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", originalUuid, UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account);
@@ -205,7 +209,7 @@ class AccountsTest {
@Test
void testStoreMulti() {
final List<Device> devices = List.of(generateDevice(1), generateDevice(2));
final List<Device> devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices);
accounts.create(account);
@@ -218,13 +222,13 @@ class AccountsTest {
@Test
void testRetrieve() {
final List<Device> devicesFirst = List.of(generateDevice(1), generateDevice(2));
final List<Device> devicesFirst = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
UUID uuidFirst = UUID.randomUUID();
UUID pniFirst = UUID.randomUUID();
Account accountFirst = generateAccount("+14151112222", uuidFirst, pniFirst, devicesFirst);
final List<Device> devicesSecond = List.of(generateDevice(1), generateDevice(2));
final List<Device> devicesSecond = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
UUID uuidSecond = UUID.randomUUID();
UUID pniSecond = UUID.randomUUID();
@@ -263,7 +267,7 @@ class AccountsTest {
@Test
void testRetrieveNoPni() throws JsonProcessingException {
final List<Device> devices = List.of(generateDevice(1), generateDevice(2));
final List<Device> devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
final UUID uuid = UUID.randomUUID();
final Account account = generateAccount("+14151112222", uuid, null, devices);
@@ -321,7 +325,7 @@ class AccountsTest {
@Test
void testOverwrite() {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
UUID firstUuid = UUID.randomUUID();
UUID firstPni = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device));
@@ -346,7 +350,7 @@ class AccountsTest {
UUID secondUuid = UUID.randomUUID();
device = generateDevice(1);
device = generateDevice(DEVICE_ID_1);
account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device));
final boolean freshUser = accounts.create(account);
@@ -356,7 +360,7 @@ class AccountsTest {
assertPhoneNumberConstraintExists("+14151112222", firstUuid);
assertPhoneNumberIdentifierConstraintExists(firstPni, firstUuid);
device = generateDevice(1);
device = generateDevice(DEVICE_ID_1);
Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.create(invalidAccount));
@@ -364,7 +368,7 @@ class AccountsTest {
@Test
void testUpdate() {
Device device = generateDevice (1 );
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
@@ -389,7 +393,7 @@ class AccountsTest {
assertThat(retrieved.isPresent()).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
device = generateDevice(1);
device = generateDevice(DEVICE_ID_1);
Account unknownAccount = generateAccount("+14151113333", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class);
@@ -452,10 +456,10 @@ class AccountsTest {
@Test
void testDelete() {
final Device deletedDevice = generateDevice(1);
final Device deletedDevice = generateDevice(DEVICE_ID_1);
final Account deletedAccount = generateAccount("+14151112222", UUID.randomUUID(),
UUID.randomUUID(), List.of(deletedDevice));
final Device retainedDevice = generateDevice(1);
final Device retainedDevice = generateDevice(DEVICE_ID_1);
final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(),
UUID.randomUUID(), List.of(retainedDevice));
@@ -485,7 +489,7 @@ class AccountsTest {
{
final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(),
UUID.randomUUID(), List.of(generateDevice(1)));
UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
final boolean freshUser = accounts.create(recreatedAccount);
@@ -501,7 +505,7 @@ class AccountsTest {
@Test
void testMissing() {
Device device = generateDevice (1 );
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
@@ -518,7 +522,7 @@ class AccountsTest {
assertThat(accounts.getByAccountIdentifierAsync(UUID.randomUUID()).join()).isEmpty();
final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1)));
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
@@ -530,7 +534,7 @@ class AccountsTest {
assertThat(accounts.getByPhoneNumberIdentifierAsync(UUID.randomUUID()).join()).isEmpty();
final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1)));
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
@@ -544,7 +548,7 @@ class AccountsTest {
assertThat(accounts.getByE164Async(e164).join()).isEmpty();
final Account account =
generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1)));
generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
@@ -553,7 +557,7 @@ class AccountsTest {
@Test
void testCanonicallyDiscoverableSet() {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
account.setDiscoverableByPhoneNumber(false);
accounts.create(account);
@@ -576,7 +580,7 @@ class AccountsTest {
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final Device device = generateDevice(1);
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account);
@@ -631,10 +635,10 @@ class AccountsTest {
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final Device existingDevice = generateDevice(1);
final Device existingDevice = generateDevice(DEVICE_ID_1);
final Account existingAccount = generateAccount(targetNumber, UUID.randomUUID(), targetPni, List.of(existingDevice));
final Device device = generateDevice(1);
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account);
@@ -653,7 +657,7 @@ class AccountsTest {
final String originalNumber = "+14151112222";
final String targetNumber = "+14151113333";
final Device device = generateDevice(1);
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
@@ -969,7 +973,48 @@ class AccountsTest {
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isPresent();
}
private static Device generateDevice(long id) {
@Test
public void testInvalidDeviceIdDeserialization() throws Exception {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
final Device device2 = generateDevice((byte) 64);
account.addDevice(device2);
accounts.create(account);
final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName())
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.build()).join();
final Map<?, ?> accountData = SystemMapper.jsonMapper()
.readValue(response.item().get(Accounts.ATTR_ACCOUNT_DATA).b().asByteArray(), Map.class);
final List<Map<Object, Object>> devices = (List<Map<Object, Object>>) accountData.get("devices");
assertEquals(Integer.valueOf(device2.getId()), devices.get(1).get("id"));
devices.get(1).put("id", Byte.MAX_VALUE + 5);
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().updateItem(UpdateItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName())
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.updateExpression("SET #data = :data")
.expressionAttributeNames(Map.of("#data", Accounts.ATTR_ACCOUNT_DATA))
.expressionAttributeValues(
Map.of(":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(accountData))))
.build()).join();
final CompletionException e = assertThrows(CompletionException.class,
() -> accounts.getByAccountIdentifierAsync(account.getUuid()).join());
Throwable cause = e.getCause();
while (cause.getCause() != null) {
cause = cause.getCause();
}
assertInstanceOf(DeviceIdDeserializer.DeviceIdDeserializationException.class, cause);
}
private static Device generateDevice(byte id) {
return DevicesHelper.createDevice(id);
}
@@ -979,7 +1024,7 @@ class AccountsTest {
}
private static Account generateAccount(String number, UUID uuid, final UUID pni) {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
return generateAccount(number, uuid, pni, List.of(device));
}

View File

@@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -68,7 +69,7 @@ public class ChangeNumberManagerTest {
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(updatedPni);
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@@ -87,7 +88,7 @@ public class ChangeNumberManagerTest {
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni);
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@@ -102,7 +103,7 @@ public class ChangeNumberManagerTest {
when(account.getNumber()).thenReturn("+18005551234");
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), eq(1L), any());
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@@ -112,7 +113,8 @@ public class ChangeNumberManagerTest {
when(account.getNumber()).thenReturn("+18005551234");
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
@@ -133,18 +135,21 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
@@ -177,19 +182,23 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, KEMSignedPreKey> pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair),
(byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@@ -220,19 +229,23 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, KEMSignedPreKey> pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair),
(byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@@ -261,18 +274,21 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
@@ -301,19 +317,23 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, KEMSignedPreKey> pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair),
(byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@@ -338,11 +358,11 @@ public class ChangeNumberManagerTest {
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.getId()).thenReturn(i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
@@ -350,15 +370,21 @@ public class ChangeNumberManagerTest {
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo"));
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds));
@@ -371,11 +397,11 @@ public class ChangeNumberManagerTest {
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.getId()).thenReturn(i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
@@ -383,15 +409,21 @@ public class ChangeNumberManagerTest {
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo"));
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds));
@@ -404,11 +436,11 @@ public class ChangeNumberManagerTest {
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.getId()).thenReturn(i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
@@ -416,11 +448,13 @@ public class ChangeNumberManagerTest {
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, 2, 2, "foo"),
new IncomingMessage(1, 3, 3, "foo"));
new IncomingMessage(1, destinationDeviceId2, 2, "foo"),
new IncomingMessage(1, destinationDeviceId3, 3, "foo"));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final Map<Byte, Integer> registrationIds = Map.of((byte) 1, 17, destinationDeviceId2, 47, destinationDeviceId3, 89);
assertThrows(IllegalArgumentException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), null, null, messages, registrationIds));

View File

@@ -40,7 +40,7 @@ class KeysManagerTest {
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
private static final byte DEVICE_ID = 1;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@@ -169,7 +169,8 @@ class KeysManagerTest {
generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
final byte deviceId2 = DEVICE_ID + 1;
keysManager.store(ACCOUNT_UUID, deviceId2,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
@@ -180,10 +181,10 @@ class KeysManagerTest {
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
keysManager.delete(ACCOUNT_UUID).join();
@@ -191,10 +192,10 @@ class KeysManagerTest {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
}
@Test
@@ -206,7 +207,8 @@ class KeysManagerTest {
generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
final byte deviceId2 = DEVICE_ID + 1;
keysManager.store(ACCOUNT_UUID, deviceId2,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
@@ -217,10 +219,10 @@ class KeysManagerTest {
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
keysManager.delete(ACCOUNT_UUID, DEVICE_ID).join();
@@ -228,10 +230,10 @@ class KeysManagerTest {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
}
@Test
@@ -240,21 +242,29 @@ class KeysManagerTest {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).join().isPresent());
final byte deviceId2 = 2;
final byte deviceId3 = 3;
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))).join();
Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair), (byte) 2,
KeysHelper.signedKEMPreKey(2, identityKeyPair))).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent());
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), deviceId3,
KeysHelper.signedKEMPreKey(4, identityKeyPair))).join();
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(),
"storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId(),
"storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().get().keyId(),
"storing new last-resort keys should overwrite old ones");
}
@Test
@@ -262,11 +272,14 @@ class KeysManagerTest {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null).join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), null, null, null,
KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), null,
List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair))
.join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 3), null, null, null, null).join();
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join()));
}

View File

@@ -124,17 +124,19 @@ class MessagePersisterIntegrationTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
messagesCache.insert(messageGuid, account.getUuid(), 1, message);
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message);
expectedMessages.add(message);
}
REDIS_CLUSTER_EXTENSION.getRedisCluster()
.useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY,
String.valueOf(SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), 1)) - 1)));
String.valueOf(
SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), Device.PRIMARY_ID)) - 1)));
final AtomicBoolean messagesPersisted = new AtomicBoolean(false);
messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() {
messagesManager.addMessageAvailabilityListener(account.getUuid(), Device.PRIMARY_ID,
new MessageAvailabilityListener() {
@Override
public boolean handleNewMessagesAvailable() {
return true;

View File

@@ -9,8 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
@@ -61,7 +61,7 @@ class MessagePersisterTest {
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
private static final long DESTINATION_DEVICE_ID = 7;
private static final byte DESTINATION_DEVICE_ID = 7;
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
@@ -90,9 +90,9 @@ class MessagePersisterTest {
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY, 1);
doAnswer(invocation -> {
when(messagesManager.persistMessages(any(UUID.class), anyByte(), any())).thenAnswer(invocation -> {
final UUID destinationUuid = invocation.getArgument(0);
final long destinationDeviceId = invocation.getArgument(1);
final byte destinationDeviceId = invocation.getArgument(1);
final List<MessageProtos.Envelope> messages = invocation.getArgument(2);
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
@@ -101,8 +101,8 @@ class MessagePersisterTest {
messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())).get();
}
return null;
}).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any());
return messages.size();
});
}
@AfterEach
@@ -153,7 +153,7 @@ class MessagePersisterTest {
messagePersister.persistNextQueues(now);
verify(messagesDynamoDb, never()).store(any(), any(), anyLong());
verify(messagesDynamoDb, never()).store(any(), any(), anyByte());
}
@Test
@@ -166,7 +166,7 @@ class MessagePersisterTest {
for (int i = 0; i < queueCount; i++) {
final String queueName = generateRandomQueueNameForSlot(slot);
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName);
final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10);
final Account account = mock(Account.class);
@@ -183,7 +183,7 @@ class MessagePersisterTest {
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong());
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyByte());
assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@@ -219,7 +219,7 @@ class MessagePersisterTest {
setNextSlotToPersist(SlotHash.getSlot(queueName));
// returning `0` indicates something not working correctly
when(messagesManager.persistMessages(any(UUID.class), anyLong(), anyList())).thenReturn(0);
when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenReturn(0);
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
assertThrows(MessagePersistenceException.class,
@@ -228,22 +228,23 @@ class MessagePersisterTest {
@SuppressWarnings("SameParameterValue")
private static String generateRandomQueueNameForSlot(final int slot) {
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid + "::";
while (true) {
for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid + "::";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
for (byte deviceId = 1; deviceId < Device.MAXIMUM_DEVICE_ID; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
}
}
}
throw new IllegalStateException("Could not find a queue name for slot " + slot);
}
private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount,
private void insertMessages(final UUID accountUuid, final byte deviceId, final int messageCount,
final Instant firstMessageTimestamp) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();

View File

@@ -85,7 +85,7 @@ class MessagesCacheTest {
private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final int DESTINATION_DEVICE_ID = 7;
private static final byte DESTINATION_DEVICE_ID = 7;
@BeforeEach
void setUp() throws Exception {
@@ -311,7 +311,7 @@ class MessagesCacheTest {
void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
@@ -323,7 +323,7 @@ class MessagesCacheTest {
messagesCache.clear(DESTINATION_UUID, DESTINATION_DEVICE_ID).join();
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(messageCount, get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size());
assertEquals(messageCount, get(DESTINATION_UUID, (byte) (DESTINATION_DEVICE_ID + 1), messageCount).size());
}
@ParameterizedTest
@@ -331,7 +331,7 @@ class MessagesCacheTest {
void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
@@ -343,7 +343,7 @@ class MessagesCacheTest {
messagesCache.clear(DESTINATION_UUID).join();
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount));
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, (byte) (DESTINATION_DEVICE_ID + 1), messageCount));
}
@Test
@@ -531,7 +531,7 @@ class MessagesCacheTest {
});
}
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final long destinationDeviceId,
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
final int messageCount) {
return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId))
.take(messageCount, true)
@@ -605,7 +605,7 @@ class MessagesCacheTest {
.thenReturn(Flux.from(emptyFinalPagePublisher))
.thenReturn(Flux.empty());
final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), 1L);
final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID);
// Why initialValue = 3?
// 1. messagesCache.getAllMessages() above produces the first call
@@ -691,7 +691,7 @@ class MessagesCacheTest {
when(asyncCommands.evalsha(any(), any(), any(), any()))
.thenReturn((RedisFuture) removeSuccess);
final Publisher<?> allMessages = messagesCache.get(UUID.randomUUID(), 1L);
final Publisher<?> allMessages = messagesCache.get(UUID.randomUUID(), Device.PRIMARY_ID);
StepVerifier.setDefaultTimeout(Duration.ofSeconds(5));
@@ -752,7 +752,7 @@ class MessagesCacheTest {
.setDestinationUuid(UUID.randomUUID().toString());
if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(256))
envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1)
.setSourceUuid(UUID.randomUUID().toString());
}

View File

@@ -98,7 +98,7 @@ class MessagesDynamoDbTest {
@Test
void testSimpleFetchAfterInsert() {
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1);
messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId);
final List<MessageProtos.Envelope> messagesStored = load(destinationUuid, destinationDeviceId,
@@ -116,11 +116,12 @@ class MessagesDynamoDbTest {
@ValueSource(ints = {10, 100, 100, 1_000, 3_000})
void testLoadManyAfterInsert(final int messageCount) {
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1);
final List<MessageProtos.Envelope> messages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
messages.add(MessageHelper.createMessage(UUID.randomUUID(), Device.PRIMARY_ID, destinationUuid, (i + 1L) * 1000,
"message " + i));
}
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
@@ -148,18 +149,20 @@ class MessagesDynamoDbTest {
void testLimitedLoad() {
final int messageCount = 200;
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1);
final List<MessageProtos.Envelope> messages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
messages.add(MessageHelper.createMessage(UUID.randomUUID(), Device.PRIMARY_ID, destinationUuid, (i + 1L) * 1000,
"message " + i));
}
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
final int messageLoadLimit = 100;
final int halfOfMessageLoadLimit = messageLoadLimit / 2;
final Publisher<?> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, messageLoadLimit);
final Publisher<?> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
messageLoadLimit);
StepVerifier.setDefaultTimeout(Duration.ofSeconds(10));
@@ -170,7 +173,7 @@ class MessagesDynamoDbTest {
.thenRequest(halfOfMessageLoadLimit)
.expectNextCount(halfOfMessageLoadLimit)
// the first 100 should be fetched and buffered, but further requests should fail
.then(() -> DYNAMO_DB_EXTENSION.stopServer())
.then(DYNAMO_DB_EXTENSION::stopServer)
.thenRequest(halfOfMessageLoadLimit)
.expectNextCount(halfOfMessageLoadLimit)
// weve consumed all the buffered messages, so a single request will fail
@@ -183,22 +186,23 @@ class MessagesDynamoDbTest {
void testDeleteForDestination() {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte deviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, deviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid).join();
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@@ -206,23 +210,26 @@ class MessagesDynamoDbTest {
void testDeleteForDestinationDevice() {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte destinationDeviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2).join();
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, destinationDeviceId2).join();
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@@ -230,15 +237,17 @@ class MessagesDynamoDbTest {
void testDeleteMessageByDestinationAndGuid() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte destinationDeviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
final Optional<MessageProtos.Envelope> deletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid(
@@ -247,11 +256,12 @@ class MessagesDynamoDbTest {
assertThat(deletedMessage).isPresent();
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
final Optional<MessageProtos.Envelope> alreadyDeletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid(
@@ -266,29 +276,32 @@ class MessagesDynamoDbTest {
void testDeleteSingleMessage() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte destinationDeviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessage(secondDestinationUuid, 1,
messagesDynamoDb.deleteMessage(secondDestinationUuid, Device.PRIMARY_ID,
UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()).get(1, TimeUnit.SECONDS);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
private List<MessageProtos.Envelope> load(final UUID destinationUuid, final long destinationDeviceId,
private List<MessageProtos.Envelope> load(final UUID destinationUuid, final byte destinationDeviceId,
final int count) {
return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDeviceId, count))
.take(count, true)

View File

@@ -34,7 +34,7 @@ class MessagesManagerTest {
final UUID destinationUuid = UUID.randomUUID();
messagesManager.insert(destinationUuid, 1L, message);
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, message);
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
@@ -42,7 +42,7 @@ class MessagesManagerTest {
.setSourceUuid(destinationUuid.toString())
.build();
messagesManager.insert(destinationUuid, 1L, syncMessage);
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage);
verifyNoMoreInteractions(reportMessageManager);
}

View File

@@ -25,7 +25,7 @@ class RefreshingAccountAndDeviceSupplierTest {
final AccountsManager accountsManager = mock(AccountsManager.class);
final UUID uuid = UUID.randomUUID();
final long deviceId = 2L;
final byte deviceId = 2;
final Account initialAccount = mock(Account.class);
final Device initialDevice = mock(Device.class);

View File

@@ -50,8 +50,8 @@ class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTes
@Test
void storeIfAbsent() {
final UUID identifier = UUID.randomUUID();
final long deviceIdWithExistingKey = 1;
final long deviceIdWithoutExistingKey = deviceIdWithExistingKey + 1;
final byte deviceIdWithExistingKey = 1;
final byte deviceIdWithoutExistingKey = deviceIdWithExistingKey + 1;
final ECSignedPreKey originalSignedPreKey = generateSignedPreKey();

View File

@@ -24,11 +24,11 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
void storeFind() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join());
assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), Device.PRIMARY_ID).join());
{
final UUID identifier = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final K signedPreKey = generateSignedPreKey();
assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join());
@@ -37,14 +37,15 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
{
final UUID identifier = UUID.randomUUID();
final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
final byte deviceId2 = 2;
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
assertDoesNotThrow(() -> keys.store(identifier, signedPreKeys).join());
assertEquals(Optional.of(signedPreKeys.get(1L)), keys.find(identifier, 1).join());
assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join());
assertEquals(Optional.of(signedPreKeys.get(Device.PRIMARY_ID)), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
}
}
@@ -54,32 +55,33 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join());
final byte deviceId2 = 2;
{
final UUID identifier = UUID.randomUUID();
final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, 1).join();
keys.delete(identifier, Device.PRIMARY_ID).join();
assertEquals(Optional.empty(), keys.find(identifier, 1).join());
assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join());
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
}
{
final UUID identifier = UUID.randomUUID();
final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier).join();
assertEquals(Optional.empty(), keys.find(identifier, 1).join());
assertEquals(Optional.empty(), keys.find(identifier, 2).join());
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join());
}
}
}

View File

@@ -5,24 +5,15 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.PreKey;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
@@ -37,7 +28,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join());
@@ -58,7 +49,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
@@ -78,7 +69,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
@@ -90,12 +81,12 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join();
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId + 1).join());
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
}
@Test
@@ -103,7 +94,7 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
@@ -115,11 +106,11 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join();
preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join();
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join());
}
}

View File

@@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.tests.util;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
@@ -61,9 +62,9 @@ public class AccountsHelper {
return markStale ? copyAndMarkStale(account) : account;
});
when(mockAccountsManager.updateDevice(any(), anyLong(), any())).thenAnswer(answer -> {
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final Long deviceId = answer.getArgument(1, Long.class);
final byte deviceId = answer.getArgument(1, Byte.class);
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
return markStale ? copyAndMarkStale(account) : account;

View File

@@ -121,12 +121,12 @@ public class AuthHelper {
when(VALID_DEVICE_3_PRIMARY.isPrimary()).thenReturn(true);
when(VALID_DEVICE_3_LINKED.isPrimary()).thenReturn(false);
when(VALID_DEVICE.getId()).thenReturn(1L);
when(VALID_DEVICE_TWO.getId()).thenReturn(1L);
when(DISABLED_DEVICE.getId()).thenReturn(1L);
when(UNDISCOVERABLE_DEVICE.getId()).thenReturn(1L);
when(VALID_DEVICE_3_PRIMARY.getId()).thenReturn(1L);
when(VALID_DEVICE_3_LINKED.getId()).thenReturn(2L);
when(VALID_DEVICE.getId()).thenReturn(Device.PRIMARY_ID);
when(VALID_DEVICE_TWO.getId()).thenReturn(Device.PRIMARY_ID);
when(DISABLED_DEVICE.getId()).thenReturn(Device.PRIMARY_ID);
when(UNDISCOVERABLE_DEVICE.getId()).thenReturn(Device.PRIMARY_ID);
when(VALID_DEVICE_3_PRIMARY.getId()).thenReturn(Device.PRIMARY_ID);
when(VALID_DEVICE_3_LINKED.getId()).thenReturn((byte) 2);
when(VALID_DEVICE.isEnabled()).thenReturn(true);
when(VALID_DEVICE_TWO.isEnabled()).thenReturn(true);
@@ -135,17 +135,17 @@ public class AuthHelper {
when(VALID_DEVICE_3_PRIMARY.isEnabled()).thenReturn(true);
when(VALID_DEVICE_3_LINKED.isEnabled()).thenReturn(true);
when(VALID_ACCOUNT.getDevice(1L)).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT_TWO.getDevice(eq(1L))).thenReturn(Optional.of(VALID_DEVICE_TWO));
when(VALID_ACCOUNT_TWO.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(VALID_DEVICE_TWO));
when(VALID_ACCOUNT_TWO.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO));
when(DISABLED_ACCOUNT.getDevice(eq(1L))).thenReturn(Optional.of(DISABLED_DEVICE));
when(DISABLED_ACCOUNT.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(DISABLED_DEVICE));
when(DISABLED_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(DISABLED_DEVICE));
when(UNDISCOVERABLE_ACCOUNT.getDevice(eq(1L))).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE));
when(UNDISCOVERABLE_ACCOUNT.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE));
when(UNDISCOVERABLE_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE));
when(VALID_ACCOUNT_3.getDevice(1L)).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
when(VALID_ACCOUNT_3.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED));
when(VALID_ACCOUNT_3.getDevice((byte) 2)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED));
when(VALID_ACCOUNT_TWO.hasEnabledLinkedDevice()).thenReturn(true);
@@ -212,7 +212,7 @@ public class AuthHelper {
DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter));
}
public static String getAuthHeader(UUID uuid, long deviceId, String password) {
public static String getAuthHeader(UUID uuid, byte deviceId, String password) {
return HeaderUtils.basicAuthHeader(uuid.toString() + "." + deviceId, password);
}
@@ -260,9 +260,9 @@ public class AuthHelper {
when(saltedTokenHash.verify(password)).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(saltedTokenHash);
when(device.isPrimary()).thenReturn(true);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(device.isEnabled()).thenReturn(true);
when(account.getDevice(1L)).thenReturn(Optional.of(device));
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
when(account.getPrimaryDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid);

View File

@@ -14,15 +14,15 @@ public class DevicesHelper {
private static final Random RANDOM = new Random();
public static Device createDevice(final long deviceId) {
public static Device createDevice(final byte deviceId) {
return createDevice(deviceId, 0);
}
public static Device createDevice(final long deviceId, final long lastSeen) {
public static Device createDevice(final byte deviceId, final long lastSeen) {
return createDevice(deviceId, lastSeen, 0);
}
public static Device createDevice(final long deviceId, final long lastSeen, final int registrationId) {
public static Device createDevice(final byte deviceId, final long lastSeen, final int registrationId) {
final Device device = new Device();
device.setId(deviceId);
device.setLastSeen(lastSeen);
@@ -34,7 +34,7 @@ public class DevicesHelper {
return device;
}
public static Device createDisabledDevice(final long deviceId, final int registrationId) {
public static Device createDisabledDevice(final byte deviceId, final int registrationId) {
final Device device = new Device();
device.setId(deviceId);
device.setUserAgent("OWT");

View File

@@ -12,7 +12,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
public class MessageHelper {
public static MessageProtos.Envelope createMessage(UUID senderUuid, final int senderDeviceId, UUID destinationUuid,
public static MessageProtos.Envelope createMessage(UUID senderUuid, final byte senderDeviceId, UUID destinationUuid,
long timestamp, String content) {
return MessageProtos.Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())

View File

@@ -35,7 +35,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
@ExtendWith(DropwizardExtensionsSupport.class)
class DestinationDeviceValidatorTest {
static Account mockAccountWithDeviceAndRegId(final Map<Long, Integer> registrationIdsByDeviceId) {
static Account mockAccountWithDeviceAndRegId(final Map<Byte, Integer> registrationIdsByDeviceId) {
final Account account = mock(Account.class);
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
@@ -48,31 +48,34 @@ class DestinationDeviceValidatorTest {
}
static Stream<Arguments> validateRegistrationIdsSource() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF)),
Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)),
Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42)),
Map.of(1L, 1492),
Set.of(1L)),
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 1492),
Set.of(id1)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42)),
Map.of(1L, 42),
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42)),
Map.of(1L, 0),
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 255)),
Map.of(1L, 0, 2L, 42),
Set.of(2L)),
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)),
Map.of(id1, 0, id2, 42),
Set.of(id2)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 256)),
Map.of(1L, 41, 2L, 257),
Set.of(1L, 2L))
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)),
Map.of(id1, 41, id2, 257),
Set.of(id1, id2))
);
}
@@ -80,8 +83,8 @@ class DestinationDeviceValidatorTest {
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Map<Long, Integer> registrationIdsByDeviceId,
Set<Long> expectedStaleDeviceIds) throws Exception {
Map<Byte, Integer> registrationIdsByDeviceId,
Set<Byte> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(
@@ -98,7 +101,7 @@ class DestinationDeviceValidatorTest {
}
}
static Account mockAccountWithDeviceAndEnabled(final Map<Long, Boolean> enabledStateByDeviceId) {
static Account mockAccountWithDeviceAndEnabled(final Map<Byte, Boolean> enabledStateByDeviceId) {
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
@@ -117,51 +120,54 @@ class DestinationDeviceValidatorTest {
}
static Stream<Arguments> validateCompleteDeviceListSource() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
return Stream.of(
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L, 3L),
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1, id3),
null,
null,
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L, 2L, 3L),
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1, id2, id3),
null,
Set.of(2L),
Set.of(id2),
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L),
Set.of(3L),
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1),
Set.of(id3),
null,
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L),
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1, id2),
Set.of(id3),
Set.of(id2),
Collections.emptySet()),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(1L),
Set.of(3L),
Set.of(1L),
Set.of(1L)
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1),
Set.of(id3),
Set.of(id1),
Set.of(id1)
),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(2L),
Set.of(3L),
Set.of(2L),
Set.of(1L)
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id2),
Set.of(id3),
Set.of(id2),
Set.of(id1)
),
arguments(
mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)),
Set.of(3L),
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id3),
null,
null,
Set.of(1L)
Set.of(id1)
)
);
}
@@ -170,10 +176,10 @@ class DestinationDeviceValidatorTest {
@MethodSource("validateCompleteDeviceListSource")
void testValidateCompleteDeviceList(
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds,
Set<Long> excludedDeviceIds) throws Exception {
Set<Byte> deviceIds,
Collection<Byte> expectedMissingDeviceIds,
Collection<Byte> expectedExtraDeviceIds,
Set<Byte> excludedDeviceIds) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,

View File

@@ -103,7 +103,7 @@ class WebSocketConnectionIntegrationTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
}
@AfterEach

View File

@@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.nullable;
@@ -162,7 +163,7 @@ class WebSocketConnectionTest {
createMessage(senderOneUuid, accountUuid, 2222, "second"),
createMessage(senderTwoUuid, accountUuid, 3333, "third"));
final long deviceId = 2L;
final byte deviceId = 2;
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
@@ -178,7 +179,7 @@ class WebSocketConnectionTest {
when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty());
when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn(
when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn(
CompletableFuture.completedFuture(Optional.empty()));
String userAgent = HttpHeaders.USER_AGENT;
@@ -232,10 +233,10 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean()))
.thenReturn(Flux.empty())
.thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")))
.thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")))
@@ -310,7 +311,7 @@ class WebSocketConnectionTest {
final List<Envelope> pendingMessages = List.of(firstMessage, secondMessage);
final long deviceId = 2L;
final byte deviceId = 2;
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
@@ -326,7 +327,7 @@ class WebSocketConnectionTest {
when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty());
when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn(
when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn(
CompletableFuture.completedFuture(Optional.empty()));
String userAgent = HttpHeaders.USER_AGENT;
@@ -374,14 +375,14 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(
messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false))
messagesManager.getMessagesForDeviceReactive(account.getUuid(), Device.PRIMARY_ID, false))
.thenAnswer(invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
@@ -428,7 +429,7 @@ class WebSocketConnectionTest {
}
});
verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), anyLong(), eq(false));
verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), anyByte(), eq(false));
}
@Test
@@ -440,7 +441,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
final List<Envelope> firstPageMessages =
@@ -450,10 +451,10 @@ class WebSocketConnectionTest {
final List<Envelope> secondPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), eq(false)))
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), eq(false)))
.thenReturn(Flux.fromStream(Stream.concat(firstPageMessages.stream(), secondPageMessages.stream())));
when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any()))
when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -492,18 +493,18 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
final UUID senderUuid = UUID.randomUUID();
final List<Envelope> messages = List.of(
createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first"));
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false))
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), Device.PRIMARY_ID, false))
.thenReturn(Flux.fromIterable(messages))
.thenReturn(Flux.empty());
when(messagesManager.delete(eq(accountUuid), eq(1L), any(UUID.class), any()))
when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(UUID.class), any()))
.thenReturn(CompletableFuture.completedFuture(null));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -555,10 +556,10 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean()))
.thenReturn(Flux.empty());
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -583,7 +584,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
final List<Envelope> firstPageMessages =
@@ -593,12 +594,12 @@ class WebSocketConnectionTest {
final List<Envelope> secondPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean()))
.thenReturn(Flux.fromIterable(firstPageMessages))
.thenReturn(Flux.fromIterable(secondPageMessages))
.thenReturn(Flux.empty());
when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any()))
when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -640,10 +641,10 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean()))
.thenReturn(Flux.empty());
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -672,10 +673,10 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean()))
.thenReturn(Flux.empty());
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -695,7 +696,7 @@ class WebSocketConnectionTest {
void testRetrieveMessageException() {
UUID accountUuid = UUID.randomUUID();
when(device.getId()).thenReturn(2L);
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@@ -725,7 +726,7 @@ class WebSocketConnectionTest {
void testRetrieveMessageExceptionClientDisconnected() {
UUID accountUuid = UUID.randomUUID();
when(device.getId()).thenReturn(2L);
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@@ -748,7 +749,7 @@ class WebSocketConnectionTest {
void testReactivePublisherLimitRate() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 2L;
final byte deviceId = 2;
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
@@ -767,7 +768,7 @@ class WebSocketConnectionTest {
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse));
when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn(
when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn(
CompletableFuture.completedFuture(Optional.empty()));
WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
@@ -798,7 +799,7 @@ class WebSocketConnectionTest {
void testReactivePublisherDisposedWhenConnectionStopped() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 2L;
final byte deviceId = 2;
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
@@ -824,7 +825,7 @@ class WebSocketConnectionTest {
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse));
when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn(
when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn(
CompletableFuture.completedFuture(Optional.empty()));
WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,