Add support for "atomic" device linking/activation

This commit is contained in:
Jon Chambers
2023-05-18 11:40:09 -04:00
committed by Jon Chambers
parent ae7cb8036e
commit b034a088b1
4 changed files with 502 additions and 53 deletions

View File

@@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.eq;
@@ -21,10 +22,13 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import javax.ws.rs.Path;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
@@ -35,14 +39,24 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
@@ -56,6 +70,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@ExtendWith(DropwizardExtensionsSupport.class)
@@ -121,6 +136,7 @@ class DeviceControllerTest {
when(account.getNextDeviceId()).thenReturn(42L);
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
when(account.isEnabled()).thenReturn(false);
when(account.isSenderKeySupported()).thenReturn(true);
when(account.isAnnouncementGroupSupported()).thenReturn(true);
@@ -225,6 +241,282 @@ class DeviceControllerTest {
assertEquals(401, response.getStatus());
}
@ParameterizedTest
@MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomic(final boolean fetchesMessages,
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId,
final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
when(account.getIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(aciIdentityKeyPair));
when(account.getPhoneNumberIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(pniIdentityKeyPair));
final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(fetchesMessages, 1234, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
final DeviceResponse response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L);
final ArgumentCaptor<Device> deviceCaptor = ArgumentCaptor.forClass(Device.class);
verify(account).addDevice(deviceCaptor.capture());
final Device device = deviceCaptor.getValue();
assertEquals(aciSignedPreKey.get(), device.getSignedPreKey());
assertEquals(pniSignedPreKey.get(), device.getPhoneNumberIdentitySignedPreKey());
assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),
() -> assertNull(device.getApnId()));
expectedApnsVoipToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getVoipApnId()),
() -> assertNull(device.getVoipApnId()));
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId()));
verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER);
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID);
verify(keys).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keys).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
}
private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token";
return Stream.of(
Arguments.of(true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, null)), Optional.empty(), Optional.of(apnsToken), Optional.empty(), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty(), Optional.of(apnsToken), Optional.of(apnsVoipToken), Optional.empty()),
Arguments.of(false, Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken)), Optional.empty(), Optional.empty(), Optional.of(gcmToken))
);
}
@ParameterizedTest
@MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicConflictingChannel(final boolean fetchesMessages,
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
when(account.getIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(aciIdentityKeyPair));
when(account.getPhoneNumberIdentityKey()).thenReturn(KeysHelper.serializeIdentityKey(pniIdentityKeyPair));
final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(fetchesMessages, 1234, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(422, response.getStatus());
}
}
private static Stream<Arguments> linkDeviceAtomicConflictingChannel() {
return Stream.of(
Arguments.of(true, Optional.of(new ApnRegistrationId("apns-token", null)), Optional.of(new GcmRegistrationId("gcm-token"))),
Arguments.of(true, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-token"))),
Arguments.of(true, Optional.of(new ApnRegistrationId("apns-token", null)), Optional.empty()),
Arguments.of(false, Optional.of(new ApnRegistrationId("apns-token", null)), Optional.of(new GcmRegistrationId("gcm-token")))
);
}
@ParameterizedTest
@MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicMissingProperty(final String aciIdentityKey,
final String pniIdentityKey,
final Optional<SignedPreKey> aciSignedPreKey,
final Optional<SignedPreKey> pniSignedPreKey,
final Optional<SignedPreKey> aciPqLastResortPreKey,
final Optional<SignedPreKey> pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
when(account.getIdentityKey()).thenReturn(aciIdentityKey);
when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(true, 1234, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(422, response.getStatus());
}
}
private static Stream<Arguments> linkDeviceAtomicMissingProperty() {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Optional<SignedPreKey> aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
final Optional<SignedPreKey> pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Optional<SignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
final Optional<SignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final String aciIdentityKey = KeysHelper.serializeIdentityKey(aciIdentityKeyPair);
final String pniIdentityKey = KeysHelper.serializeIdentityKey(pniIdentityKeyPair);
return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, Optional.empty(), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, Optional.empty(), aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, Optional.empty(), pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, Optional.empty())
);
}
@ParameterizedTest
@MethodSource
void linkDeviceAtomicInvalidSignature(final String aciIdentityKey,
final String pniIdentityKey,
final SignedPreKey aciSignedPreKey,
final SignedPreKey pniSignedPreKey,
final SignedPreKey aciPqLastResortPreKey,
final SignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
when(account.getIdentityKey()).thenReturn(aciIdentityKey);
when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(true, 1234, null, null, true, null),
new DeviceActivationRequest(Optional.of(aciSignedPreKey), Optional.of(pniSignedPreKey), Optional.of(aciPqLastResortPreKey), Optional.of(pniPqLastResortPreKey), Optional.empty(), Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(422, response.getStatus());
}
}
private static Stream<Arguments> linkDeviceAtomicInvalidSignature() {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final SignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
final SignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final String aciIdentityKey = KeysHelper.serializeIdentityKey(aciIdentityKeyPair);
final String pniIdentityKey = KeysHelper.serializeIdentityKey(pniIdentityKeyPair);
return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, signedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, signedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, signedPreKeyWithBadSignature(pniPqLastResortPreKey))
);
}
private static SignedPreKey signedPreKeyWithBadSignature(final SignedPreKey signedPreKey) {
return new SignedPreKey(signedPreKey.getKeyId(),
signedPreKey.getPublicKey(),
Base64.getEncoder().encodeToString("incorrect-signature".getBytes(StandardCharsets.UTF_8)));
}
@Test
void disabledDeviceRegisterTest() {
Response response = resources.getJerseyTest()

View File

@@ -130,6 +130,7 @@ public class AccountsHelper {
case "getEnabledDeviceCount" -> when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing);
case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
case "getIdentityKey" -> when(updatedAccount.getIdentityKey()).thenAnswer(stubbing);
case "getPhoneNumberIdentityKey" -> when(updatedAccount.getPhoneNumberIdentityKey()).thenAnswer(stubbing);
case "getBadges" -> when(updatedAccount.getBadges()).thenAnswer(stubbing);
case "getLastSeen" -> when(updatedAccount.getLastSeen()).thenAnswer(stubbing);
case "hasLockedCredentials" -> when(updatedAccount.hasLockedCredentials()).thenAnswer(stubbing);