Manage device linking tokens transactionally

This commit is contained in:
Jon Chambers
2024-10-07 16:26:11 -04:00
committed by GitHub
parent 42e920cd5c
commit f7aacefc40
18 changed files with 539 additions and 308 deletions

View File

@@ -25,12 +25,10 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -73,11 +71,11 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@@ -102,16 +100,10 @@ class DeviceControllerTest {
private static final byte NEXT_DEVICE_ID = 42;
private static DeviceController deviceController = new DeviceController(
generateLinkDeviceSecret(),
accountsManager,
clientPublicKeysManager,
rateLimiters,
RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build(),
deviceConfiguration,
testClock);
deviceConfiguration);
@RegisterExtension
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
@@ -126,10 +118,6 @@ class DeviceControllerTest {
.addResource(deviceController)
.build();
private static byte[] generateLinkDeviceSecret() {
return TestRandomUtil.nextBytes(32);
}
@BeforeEach
void setup() {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
@@ -183,12 +171,6 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
@@ -205,7 +187,9 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
@@ -217,7 +201,7 @@ class DeviceControllerTest {
final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null,
null, true, new DeviceCapabilities(true, true, true, false, false));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
@@ -230,7 +214,7 @@ class DeviceControllerTest {
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture());
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture(), any());
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
@@ -241,8 +225,6 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId()));
verify(asyncCommands).set(anyString(), anyString(), any());
}
private static Stream<Arguments> linkDeviceAtomic() {
@@ -261,7 +243,7 @@ class DeviceControllerTest {
@MethodSource
void deviceDowngradeDeleteSync(final boolean accountSupportsDeleteSync, final boolean deviceSupportsDeleteSync, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.addDevice(any(), any()))
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class))));
final Device primaryDevice = mock(Device.class);
@@ -287,7 +269,9 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsDeleteSync, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -314,7 +298,7 @@ class DeviceControllerTest {
void deviceDowngradeVersionedExpirationTimer(final boolean accountSupportsVersionedExpirationTimer,
final boolean deviceSupportsVersionedExpirationTimer, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.addDevice(any(), any()))
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class))));
final Device primaryDevice = mock(Device.class);
@@ -340,7 +324,9 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsVersionedExpirationTimer, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -386,7 +372,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -400,6 +386,52 @@ class DeviceControllerTest {
}
}
@Test
void linkDeviceAtomicReusedToken() {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new LinkDeviceTokenAlreadyUsedException()));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(true, 1234, 5678, null,
null, true, new DeviceCapabilities(true, true, true, false, false));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(403, response.getStatus());
}
}
@Test
void linkDeviceAtomicWithVerificationTokenUsed() {
@@ -427,7 +459,7 @@ class DeviceControllerTest {
when(commands.get(anyString())).thenReturn("");
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -577,16 +609,12 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@@ -614,17 +642,12 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@@ -683,7 +706,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, TestRandomUtil.nextBytes(512), null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -704,12 +727,6 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@@ -721,16 +738,18 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, new DeviceCapabilities(true, true, true, false, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn")), Optional.empty()));
@@ -785,7 +804,7 @@ class DeviceControllerTest {
.get();
assertEquals(411, response.getStatus());
verify(accountsManager, never()).addDevice(any(), any());
verify(accountsManager, never()).addDevice(any(), any(), any());
}
@Test
@@ -898,60 +917,6 @@ class DeviceControllerTest {
}
}
@Test
void checkVerificationToken() {
final UUID uuid = UUID.randomUUID();
assertEquals(Optional.of(uuid),
deviceController.checkVerificationToken(deviceController.generateVerificationToken(uuid)));
}
@ParameterizedTest
@MethodSource
void checkVerificationTokenBadToken(final String token, final Instant currentTime) {
testClock.pin(currentTime);
assertEquals(Optional.empty(),
deviceController.checkVerificationToken(token));
}
private static Stream<Arguments> checkVerificationTokenBadToken() {
final Instant tokenTimestamp = testClock.instant();
return Stream.of(
// Expired token
Arguments.of(deviceController.generateVerificationToken(UUID.randomUUID()),
tokenTimestamp.plus(DeviceController.TOKEN_EXPIRATION_DURATION).plusSeconds(1)),
// Bad UUID
Arguments.of("not-a-valid-uuid.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No UUID
Arguments.of(".1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Bad timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.not-a-valid-timestamp:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Blank timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171", tokenTimestamp),
// Blank signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:", tokenTimestamp),
// Incorrect signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Invalid signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
@Test
void setPublicKey() {
final SetPublicKeyRequest request = new SetPublicKeyRequest(Curve.generateKeyPair().getPublicKey());