Add devices to accounts transactionally

This commit is contained in:
Jon Chambers
2023-12-07 11:19:40 -05:00
committed by GitHub
parent e084a9f2b6
commit 50d92265ea
10 changed files with 520 additions and 268 deletions

View File

@@ -24,6 +24,7 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
@@ -38,6 +39,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@@ -72,12 +74,15 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@@ -91,6 +96,7 @@ class DeviceControllerTest {
private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class);
private static RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
private static RedisAdvancedClusterAsyncCommands<String, String> asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
private static Account account = mock(Account.class);
private static Account maxedAccount = mock(Account.class);
private static Device primaryDevice = mock(Device.class);
@@ -106,7 +112,10 @@ class DeviceControllerTest {
messagesManager,
keysManager,
rateLimiters,
RedisClusterHelper.builder().stringCommands(commands).build(),
RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build(),
deviceConfiguration,
testClock);
@@ -114,6 +123,7 @@ class DeviceControllerTest {
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
@@ -166,6 +176,7 @@ class DeviceControllerTest {
rateLimiters,
rateLimiter,
commands,
asyncCommands,
account,
maxedAccount,
primaryDevice,
@@ -300,11 +311,22 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null),
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
final DeviceResponse response = resources.getJerseyTest()
@@ -315,10 +337,10 @@ class DeviceControllerTest {
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<Device> deviceCaptor = ArgumentCaptor.forClass(Device.class);
verify(account).addDevice(deviceCaptor.capture());
final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture());
final Device device = deviceCaptor.getValue();
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
@@ -333,14 +355,9 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey));
verify(commands).set(anyString(), anyString(), any());
verify(asyncCommands).set(anyString(), anyString(), any());
}
private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
@@ -596,9 +613,18 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn", null)), Optional.empty()));
@@ -719,35 +745,66 @@ class DeviceControllerTest {
verifyNoMoreInteractions(messagesManager);
}
@Test
void deviceDowngradePniTest() {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true,
false, true);
AccountAttributes accountAttributes =
new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities);
@ParameterizedTest
@MethodSource
void deviceDowngradePniTest(final boolean accountSupportsPni, final boolean deviceSupportsPni, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final String verificationToken = deviceController.generateVerificationToken(AuthHelper.VALID_UUID);
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
Response response = resources
.getJerseyTest()
.target("/v1/devices/" + verificationToken)
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(account.isPniSupported()).thenReturn(accountSupportsPni);
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, deviceSupportsPni, true));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30")
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
deviceCapabilities = new DeviceCapabilities(true, true, true, true);
accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities);
response = resources
.getJerseyTest()
.target("/v1/devices/" + verificationToken)
.request()
.header("Authorization",
AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30")
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
assertEquals(expectedStatus, response.getStatus());
}
}
private static List<Arguments> deviceDowngradePniTest() {
return List.of(
Arguments.of(true, true, 200),
Arguments.of(true, false, 409),
Arguments.of(false, true, 200),
Arguments.of(false, false, 200));
}
@Test

View File

@@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@@ -74,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@@ -167,7 +167,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId.orElse(0));
@@ -290,7 +290,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@@ -348,7 +348,7 @@ class RegistrationControllerTest {
final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(createdAccount);
expectedStatus = 200;
@@ -402,7 +402,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@@ -426,7 +426,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@@ -658,16 +658,10 @@ class RegistrationControllerTest {
@ParameterizedTest
@MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest,
final IdentityKey expectedAciIdentityKey,
final IdentityKey expectedPniIdentityKey,
final ECSignedPreKey expectedAciSignedPreKey,
final ECSignedPreKey expectedPniSignedPreKey,
final KEMSignedPreKey expectedAciPqLastResortPreKey,
final KEMSignedPreKey expectedPniPqLastResortPreKey,
final Optional<ApnRegistrationId> expectedApnRegistrationId,
final Optional<GcmRegistrationId> expectedGcmRegistrationId) throws InterruptedException {
final DeviceSpec expectedDeviceSpec) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
@@ -685,7 +679,7 @@ class RegistrationControllerTest {
when(a.getPrimaryDevice()).thenReturn(device);
});
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@@ -699,18 +693,11 @@ class RegistrationControllerTest {
verify(accountsManager).create(
eq(NUMBER),
eq(PASSWORD),
isNull(),
argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())),
eq(Collections.emptyList()),
eq(expectedAciIdentityKey),
eq(expectedPniIdentityKey),
eq(expectedAciSignedPreKey),
eq(expectedPniSignedPreKey),
eq(expectedAciPqLastResortPreKey),
eq(expectedPniPqLastResortPreKey),
eq(expectedApnRegistrationId),
eq(expectedGcmRegistrationId));
eq(expectedDeviceSpec));
}
private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) {
@@ -745,11 +732,17 @@ class RegistrationControllerTest {
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
}
final byte[] deviceName = "test".getBytes(StandardCharsets.UTF_8);
final int registrationId = 1;
final int pniRegistrationId = 2;
final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(false, false, false, false);
final AccountAttributes fetchesMessagesAccountAttributes =
new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
new AccountAttributes(true, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
final AccountAttributes pushAccountAttributes =
new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
new AccountAttributes(false, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
@@ -771,13 +764,20 @@ class RegistrationControllerTest {
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty(),
Optional.empty()),
new DeviceSpec(
deviceName,
PASSWORD,
null,
deviceCapabilities,
registrationId,
pniRegistrationId,
true,
Optional.empty(),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey)),
// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
@@ -794,36 +794,22 @@ class RegistrationControllerTest {
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
// requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
new DeviceSpec(
deviceName,
PASSWORD,
null,
deviceCapabilities,
registrationId,
pniRegistrationId,
false,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
pniPqLastResortPreKey)),
// Fetches messages; no push tokens
// Has GCM token
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
@@ -838,12 +824,21 @@ class RegistrationControllerTest {
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))));
new DeviceSpec(
deviceName,
PASSWORD,
null,
deviceCapabilities,
registrationId,
pniRegistrationId,
false,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken)),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey))
);
}
/**