Always require atomic account creation

This commit is contained in:
Jon Chambers
2023-11-11 10:07:16 -08:00
committed by Jon Chambers
parent 9069c5abb6
commit 521900c048
12 changed files with 409 additions and 584 deletions

View File

@@ -288,18 +288,18 @@ class DeviceControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@@ -324,8 +324,8 @@ class DeviceControllerTest {
final Device device = deviceCaptor.getValue();
assertEquals(aciSignedPreKey.get(), device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey.get(), device.getSignedPreKey(IdentityType.PNI));
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),
@@ -338,14 +338,13 @@ class DeviceControllerTest {
() -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey));
verify(commands).set(anyString(), anyString(), any());
}
private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
@@ -368,18 +367,18 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@@ -421,18 +420,18 @@ class DeviceControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@@ -465,10 +464,10 @@ class DeviceControllerTest {
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final Optional<ECSignedPreKey> aciSignedPreKey,
final Optional<ECSignedPreKey> pniSignedPreKey,
final Optional<KEMSignedPreKey> aciPqLastResortPreKey,
final Optional<KEMSignedPreKey> pniPqLastResortPreKey) {
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@@ -503,19 +502,19 @@ class DeviceControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Optional<ECSignedPreKey> aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
final Optional<ECSignedPreKey> pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Optional<KEMSignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
final Optional<KEMSignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, Optional.empty(), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, Optional.empty(), aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, Optional.empty(), pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, Optional.empty())
Arguments.of(aciIdentityKey, pniIdentityKey, null, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, null, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, null, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, null)
);
}
@@ -545,7 +544,7 @@ class DeviceControllerTest {
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(true, 1234, null, null, true, null),
new DeviceActivationRequest(Optional.of(aciSignedPreKey), Optional.of(pniSignedPreKey), Optional.of(aciPqLastResortPreKey), Optional.of(pniPqLastResortPreKey), Optional.empty(), Optional.empty()));
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")

View File

@@ -6,10 +6,8 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@@ -20,11 +18,11 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
@@ -49,7 +47,6 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.IdentityKey;
@@ -60,6 +57,7 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@@ -118,14 +116,20 @@ class RegistrationControllerTest {
@BeforeEach
void setUp() {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
}
@Test
public void testRegistrationRequest() throws Exception {
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
when(accountsManager.update(any(), any())).thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final Consumer<Account> accountUpdater = invocation.getArgument(1);
accountUpdater.accept(account);
return invocation.getArgument(0);
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeKemOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
}
@Test
@@ -151,32 +155,26 @@ class RegistrationControllerTest {
}
@ParameterizedTest
@MethodSource()
@MethodSource
void invalidRegistrationId(Optional<Integer> registrationId, Optional<Integer> pniRegistrationId, int statusCode) throws InterruptedException, JsonProcessingException {
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
.thenReturn(account);
final String recoveryPassword = encodeRecoveryPassword(new byte[0]);
final Map<String, Object> accountAttrs = new HashMap<>();
accountAttrs.put("recoveryPassword", recoveryPassword);
registrationId.ifPresent(id -> accountAttrs.put("registrationId", id));
pniRegistrationId.ifPresent(id -> accountAttrs.put("pniRegistrationId", id));
final String json = SystemMapper.jsonMapper().writeValueAsString(Map.of(
"sessionId", encodeSessionId("sessionId"),
"recoveryPassword", recoveryPassword,
"accountAttributes", accountAttrs,
"skipDeviceTransfer", true
));
final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId);
try (Response response = request.post(Entity.json(json))) {
assertEquals(statusCode, response.getStatus());
@@ -292,8 +290,12 @@ class RegistrationControllerTest {
void recoveryPasswordManagerVerificationTrue() throws InterruptedException {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
@@ -340,15 +342,19 @@ class RegistrationControllerTest {
expectedStatus = 409;
} else if (error != null) {
final Exception e = switch (error) {
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
case RATE_LIMITED -> new RateLimitExceededException(null, true);
};
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
case RATE_LIMITED -> new RateLimitExceededException(null, true);
};
doThrow(e)
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any());
expectedStatus = error.getExpectedStatus();
} else {
final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
.thenReturn(createdAccount);
expectedStatus = 200;
}
@@ -396,13 +402,17 @@ class RegistrationControllerTest {
maybeAccount = Optional.empty();
}
when(accountsManager.getByE164(any())).thenReturn(maybeAccount);
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(mock(Account.class));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) {
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer, 1, Optional.of(2))))) {
assertEquals(expectedStatus, response.getStatus());
}
}
@@ -415,8 +425,12 @@ class RegistrationControllerTest {
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
@@ -447,22 +461,22 @@ class RegistrationControllerTest {
}
static Stream<Arguments> atomicAccountCreationConflictingChannel() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final IdentityKey aciIdentityKey;
final IdentityKey pniIdentityKey;
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
{
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
}
final AccountAttributes fetchesMessagesAccountAttributes =
@@ -477,7 +491,6 @@ class RegistrationControllerTest {
new byte[0],
fetchesMessagesAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@@ -492,7 +505,6 @@ class RegistrationControllerTest {
new byte[0],
fetchesMessagesAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@@ -507,7 +519,6 @@ class RegistrationControllerTest {
new byte[0],
pushAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
@@ -539,22 +550,22 @@ class RegistrationControllerTest {
}
static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final IdentityKey aciIdentityKey;
final IdentityKey pniIdentityKey;
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
{
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
}
final AccountAttributes accountAttributes =
@@ -566,11 +577,10 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
Optional.empty(),
null,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
@@ -581,10 +591,9 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
Optional.empty(),
null,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
@@ -596,13 +605,12 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
Optional.empty(),
null,
Optional.empty(),
Optional.empty())),
@@ -611,12 +619,11 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
Optional.empty(),
null,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty())),
@@ -626,8 +633,7 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
Optional.empty(),
null,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
@@ -641,9 +647,8 @@ class RegistrationControllerTest {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
Optional.empty(),
null,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
@@ -689,13 +694,6 @@ class RegistrationControllerTest {
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account);
when(accountsManager.update(eq(account), any())).thenAnswer(invocation -> {
final Consumer<Account> accountUpdater = invocation.getArgument(1);
accountUpdater.accept(account);
return invocation.getArgument(0);
});
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
@@ -730,60 +728,23 @@ class RegistrationControllerTest {
() -> verify(device, never()).setGcmId(any()));
}
@ParameterizedTest
@ValueSource(booleans = {false, true})
void nonAtomicAccountCreationWithNoAtomicFields(boolean requireAtomic) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
RegistrationRequest reg = new RegistrationRequest("session-id",
new byte[0],
new AccountAttributes(true, 1, "test", null, true, new Device.DeviceCapabilities(false, false, false, false)),
true,
requireAtomic,
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty());
try (final Response response = request.post(Entity.json(reg))) {
int expected = requireAtomic ? 422 : 200;
assertEquals(expected, response.getStatus());
}
}
private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final IdentityKey aciIdentityKey;
final IdentityKey pniIdentityKey;
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
{
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
}
final AccountAttributes fetchesMessagesAccountAttributes =
@@ -796,137 +757,154 @@ class RegistrationControllerTest {
final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token";
return Stream.of(false, true)
// try with and without strict atomic checking
.flatMap(requireAtomic ->
Stream.of(
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
fetchesMessagesAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.empty()),
return Stream.of(
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
fetchesMessagesAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty(),
Optional.empty()),
// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),
// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),
// requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),
// requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken))));
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken)));
}
/**
* Valid request JSON with the give session ID and skipDeviceTransfer
*/
private static String requestJson(final String sessionId, final byte[] recoveryPassword, final boolean skipDeviceTransfer) {
final String rp = encodeRecoveryPassword(recoveryPassword);
return String.format("""
{
"sessionId": "%s",
"recoveryPassword": "%s",
"accountAttributes": {
"recoveryPassword": "%s",
"registrationId": 1
},
"skipDeviceTransfer": %s
}
""", encodeSessionId(sessionId), rp, rp, skipDeviceTransfer);
private static String requestJson(final String sessionId,
final byte[] recoveryPassword,
final boolean skipDeviceTransfer,
final int registrationId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<Integer> pniRegistrationId) {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final AccountAttributes accountAttributes = new AccountAttributes(true, registrationId, "name", "reglock", true,
new Device.DeviceCapabilities(true, true, true, true));
pniRegistrationId.ifPresent(accountAttributes::setPhoneNumberIdentityRegistrationId);
final RegistrationRequest request = new RegistrationRequest(
Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)),
recoveryPassword,
accountAttributes,
skipDeviceTransfer,
aciIdentityKey,
pniIdentityKey,
new DeviceActivationRequest(
KeysHelper.signedECPreKey(1, aciIdentityKeyPair),
KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair),
KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair),
Optional.empty(),
Optional.empty()));
try {
return SystemMapper.jsonMapper().writerWithDefaultPrettyPrinter().writeValueAsString(request);
} catch (final JsonProcessingException e) {
throw new UncheckedIOException(e);
}
}
/**
* Valid request JSON with the given session ID
*/
private static String requestJson(final String sessionId) {
return requestJson(sessionId, new byte[0], false);
return requestJson(sessionId, new byte[0], false, 1, Optional.of(2));
}
/**
* Valid request JSON with the given Recovery Password
*/
private static String requestJsonRecoveryPassword(final byte[] recoveryPassword) {
return requestJson("", recoveryPassword, false);
return requestJson("", recoveryPassword, false, 1, Optional.of(2));
}
/**
@@ -953,12 +931,4 @@ class RegistrationControllerTest {
}
""";
}
private static String encodeSessionId(final String sessionId) {
return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
}
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
return Base64.getEncoder().encodeToString(recoveryPassword);
}
}