Manage device linking tokens transactionally

This commit is contained in:
Jon Chambers
2024-10-07 16:26:11 -04:00
committed by GitHub
parent 42e920cd5c
commit f7aacefc40
18 changed files with 539 additions and 308 deletions

View File

@@ -25,12 +25,10 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -73,11 +71,11 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@@ -102,16 +100,10 @@ class DeviceControllerTest {
private static final byte NEXT_DEVICE_ID = 42;
private static DeviceController deviceController = new DeviceController(
generateLinkDeviceSecret(),
accountsManager,
clientPublicKeysManager,
rateLimiters,
RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build(),
deviceConfiguration,
testClock);
deviceConfiguration);
@RegisterExtension
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
@@ -126,10 +118,6 @@ class DeviceControllerTest {
.addResource(deviceController)
.build();
private static byte[] generateLinkDeviceSecret() {
return TestRandomUtil.nextBytes(32);
}
@BeforeEach
void setup() {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
@@ -183,12 +171,6 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
@@ -205,7 +187,9 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
@@ -217,7 +201,7 @@ class DeviceControllerTest {
final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null,
null, true, new DeviceCapabilities(true, true, true, false, false));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
@@ -230,7 +214,7 @@ class DeviceControllerTest {
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture());
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture(), any());
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
@@ -241,8 +225,6 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId()));
verify(asyncCommands).set(anyString(), anyString(), any());
}
private static Stream<Arguments> linkDeviceAtomic() {
@@ -261,7 +243,7 @@ class DeviceControllerTest {
@MethodSource
void deviceDowngradeDeleteSync(final boolean accountSupportsDeleteSync, final boolean deviceSupportsDeleteSync, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.addDevice(any(), any()))
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class))));
final Device primaryDevice = mock(Device.class);
@@ -287,7 +269,9 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsDeleteSync, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -314,7 +298,7 @@ class DeviceControllerTest {
void deviceDowngradeVersionedExpirationTimer(final boolean accountSupportsVersionedExpirationTimer,
final boolean deviceSupportsVersionedExpirationTimer, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.addDevice(any(), any()))
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class))));
final Device primaryDevice = mock(Device.class);
@@ -340,7 +324,9 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsVersionedExpirationTimer, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -386,7 +372,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -400,6 +386,52 @@ class DeviceControllerTest {
}
}
@Test
void linkDeviceAtomicReusedToken() {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new LinkDeviceTokenAlreadyUsedException()));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(true, 1234, 5678, null,
null, true, new DeviceCapabilities(true, true, true, false, false));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(403, response.getStatus());
}
}
@Test
void linkDeviceAtomicWithVerificationTokenUsed() {
@@ -427,7 +459,7 @@ class DeviceControllerTest {
when(commands.get(anyString())).thenReturn("");
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -577,16 +609,12 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@@ -614,17 +642,12 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@@ -683,7 +706,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, TestRandomUtil.nextBytes(512), null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@@ -704,12 +727,6 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@@ -721,16 +738,18 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, new DeviceCapabilities(true, true, true, false, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn")), Optional.empty()));
@@ -785,7 +804,7 @@ class DeviceControllerTest {
.get();
assertEquals(411, response.getStatus());
verify(accountsManager, never()).addDevice(any(), any());
verify(accountsManager, never()).addDevice(any(), any(), any());
}
@Test
@@ -898,60 +917,6 @@ class DeviceControllerTest {
}
}
@Test
void checkVerificationToken() {
final UUID uuid = UUID.randomUUID();
assertEquals(Optional.of(uuid),
deviceController.checkVerificationToken(deviceController.generateVerificationToken(uuid)));
}
@ParameterizedTest
@MethodSource
void checkVerificationTokenBadToken(final String token, final Instant currentTime) {
testClock.pin(currentTime);
assertEquals(Optional.empty(),
deviceController.checkVerificationToken(token));
}
private static Stream<Arguments> checkVerificationTokenBadToken() {
final Instant tokenTimestamp = testClock.instant();
return Stream.of(
// Expired token
Arguments.of(deviceController.generateVerificationToken(UUID.randomUUID()),
tokenTimestamp.plus(DeviceController.TOKEN_EXPIRATION_DURATION).plusSeconds(1)),
// Bad UUID
Arguments.of("not-a-valid-uuid.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No UUID
Arguments.of(".1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Bad timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.not-a-valid-timestamp:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Blank timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171", tokenTimestamp),
// Blank signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:", tokenTimestamp),
// Incorrect signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Invalid signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
@Test
void setPublicKey() {
final SetPublicKeyRequest request = new SetPublicKeyRequest(Curve.generateKeyPair().getPublicKey());

View File

@@ -105,7 +105,8 @@ public class AccountCreationDeletionIntegrationTest {
DynamoDbExtensionSchema.Tables.NUMBERS.tableName(),
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(),
DynamoDbExtensionSchema.Tables.USERNAMES.tableName(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName());
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor();
@@ -141,6 +142,7 @@ public class AccountCreationDeletionIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
keysManager,
messagesManager,
@@ -153,6 +155,7 @@ public class AccountCreationDeletionIntegrationTest {
accountLockExecutor,
clientPresenceExecutor,
CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager);
}

View File

@@ -98,7 +98,8 @@ class AccountsManagerChangeNumberIntegrationTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor();
@@ -136,6 +137,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
keysManager,
messagesManager,
@@ -148,6 +150,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accountLockExecutor,
clientPresenceExecutor,
mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager);
}
}

View File

@@ -93,7 +93,8 @@ class AccountsManagerConcurrentModificationIntegrationTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
{
//noinspection unchecked
@@ -123,6 +124,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.builder().stringCommands(commands).build(),
RedisClusterHelper.builder().stringCommands(commands).build(),
accountLockManager,
mock(KeysManager.class),
mock(MessagesManager.class),
@@ -135,6 +137,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(Executor.class),
mock(Executor.class),
mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager
);
}

View File

@@ -37,7 +37,9 @@ import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
@@ -75,6 +77,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
@@ -89,6 +92,7 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import javax.crypto.spec.SecretKeySpec;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class AccountsManagerTest {
@@ -102,6 +106,10 @@ class AccountsManagerTest {
private static final byte[] ENCRYPTED_USERNAME_1 = Base64.getUrlDecoder().decode(BASE_64_URL_ENCRYPTED_USERNAME_1);
private static final byte[] ENCRYPTED_USERNAME_2 = Base64.getUrlDecoder().decode(BASE_64_URL_ENCRYPTED_USERNAME_2);
private static final byte[] LINK_DEVICE_SECRET = "link-device-secret".getBytes(StandardCharsets.UTF_8);
private static TestClock CLOCK;
private Accounts accounts;
private KeysManager keysManager;
private MessagesManager messagesManager;
@@ -113,7 +121,6 @@ class AccountsManagerTest {
private RedisAdvancedClusterCommands<String, String> commands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands;
private TestClock clock;
private AccountsManager accountsManager;
private SecureValueRecovery2Client svr2Client;
private DynamicConfiguration dynamicConfiguration;
@@ -161,6 +168,7 @@ class AccountsManagerTest {
asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
when(asyncCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L));
when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null));
@@ -220,16 +228,18 @@ class AccountsManagerTest {
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
clock = TestClock.now();
CLOCK = TestClock.now();
final FaultTolerantRedisCluster redisCluster = RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build();
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build(),
redisCluster,
redisCluster,
accountLockManager,
keysManager,
messagesManager,
@@ -241,7 +251,8 @@ class AccountsManagerTest {
clientPublicKeysManager,
mock(Executor.class),
clientPresenceExecutor,
clock,
CLOCK,
LINK_DEVICE_SECRET,
dynamicConfigurationManager);
}
@@ -920,7 +931,7 @@ class AccountsManagerTest {
PhoneNumberUtil.getInstance().format(PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(clock.millis())));
final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(CLOCK.millis())));
final UUID aci = account.getIdentifier(IdentityType.ACI);
final UUID pni = account.getIdentifier(IdentityType.PNI);
@@ -945,7 +956,7 @@ class AccountsManagerTest {
when(accounts.getByAccountIdentifierAsync(aci)).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
clock.pin(clock.instant().plusSeconds(60));
CLOCK.pin(CLOCK.instant().plusSeconds(60));
final Pair<Account, Device> updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec(
deviceNameCiphertext,
@@ -960,7 +971,8 @@ class AccountsManagerTest {
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey))
pniPqLastResortPreKey),
accountsManager.generateDeviceLinkingToken(aci))
.join();
verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId);
@@ -1589,4 +1601,59 @@ class AccountsManagerTest {
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
null);
}
@Test
void checkDeviceLinkingToken() {
final UUID aci = UUID.randomUUID();
assertEquals(Optional.of(aci),
accountsManager.checkDeviceLinkingToken(accountsManager.generateDeviceLinkingToken(aci)));
}
@ParameterizedTest
@MethodSource
void checkVerificationTokenBadToken(final String token, final Instant currentTime) {
CLOCK.pin(currentTime);
assertEquals(Optional.empty(), accountsManager.checkDeviceLinkingToken(token));
}
private static Stream<Arguments> checkVerificationTokenBadToken() throws InvalidKeyException {
final Instant tokenTimestamp = Instant.now();
return Stream.of(
// Expired token
Arguments.of(AccountsManager.generateDeviceLinkingToken(UUID.randomUUID(),
new SecretKeySpec(LINK_DEVICE_SECRET, AccountsManager.LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM),
CLOCK),
tokenTimestamp.plus(AccountsManager.LINK_DEVICE_TOKEN_EXPIRATION_DURATION).plusSeconds(1)),
// Bad UUID
Arguments.of("not-a-valid-uuid.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No UUID
Arguments.of(".1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Bad timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.not-a-valid-timestamp:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Blank timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171", tokenTimestamp),
// Blank signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:", tokenTimestamp),
// Incorrect signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Invalid signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
}

View File

@@ -14,6 +14,7 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
@@ -105,7 +106,8 @@ class AccountsManagerUsernameIntegrationTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()));
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName()));
final AccountLockManager accountLockManager = mock(AccountLockManager.class);
@@ -135,6 +137,7 @@ class AccountsManagerUsernameIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
keysManager,
messageManager,
@@ -147,6 +150,7 @@ class AccountsManagerUsernameIntegrationTest {
Executors.newSingleThreadExecutor(),
Executors.newSingleThreadExecutor(),
mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager);
}

View File

@@ -106,6 +106,7 @@ class AccountsTest {
Tables.PNI_ASSIGNMENTS,
Tables.USERNAMES,
Tables.DELETED_ACCOUNTS,
Tables.USED_LINK_DEVICE_TOKENS,
// This is an unrelated table used to test "tag-along" transactional updates
Tables.CLIENT_RELEASES);
@@ -132,7 +133,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
}
@Test
@@ -560,7 +562,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
Exception e = TransactionConflictException.builder().build();
e = wrapException ? new CompletionException(e) : e;
@@ -648,7 +651,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
when(dynamoDbAsyncClient.transactWriteItems(any(TransactWriteItemsRequest.class)))
.thenReturn(CompletableFuture.failedFuture(TransactionCanceledException.builder()
@@ -1039,7 +1043,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID());
createAccount(account);
@@ -1081,7 +1086,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID());
createAccount(account);
@@ -1181,7 +1187,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID());
createAccount(account);

View File

@@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
@@ -45,6 +46,7 @@ public class AddRemoveDeviceIntegrationTest {
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK,
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS,
DynamoDbExtensionSchema.Tables.NUMBERS,
DynamoDbExtensionSchema.Tables.PNI,
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS,
@@ -93,7 +95,8 @@ public class AddRemoveDeviceIntegrationTest {
DynamoDbExtensionSchema.Tables.NUMBERS.tableName(),
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(),
DynamoDbExtensionSchema.Tables.USERNAMES.tableName(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName());
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor();
@@ -129,6 +132,7 @@ public class AddRemoveDeviceIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
keysManager,
messagesManager,
@@ -141,6 +145,7 @@ public class AddRemoveDeviceIntegrationTest {
accountLockExecutor,
clientPresenceExecutor,
CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager);
}
@@ -182,7 +187,8 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)))
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
.join();
assertEquals(2, updatedAccountAndDevice.first().getDevices().size());
@@ -199,6 +205,67 @@ public class AddRemoveDeviceIntegrationTest {
assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent());
}
@Test
void addDeviceReusedToken() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size());
final String linkDeviceToken = accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI));
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
assertEquals(2,
accountsManager.getByAccountIdentifier(updatedAccountAndDevice.first().getUuid()).orElseThrow().getDevices()
.size());
final CompletionException completionException = assertThrows(CompletionException.class,
() -> accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join());
assertInstanceOf(LinkDeviceTokenAlreadyUsedException.class, completionException.getCause());
assertEquals(2,
accountsManager.getByAccountIdentifier(updatedAccountAndDevice.first().getUuid()).orElseThrow().getDevices()
.size());
}
@Test
void removeDevice() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
@@ -225,7 +292,8 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)))
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
.join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId();
@@ -278,7 +346,8 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)))
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
.join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId();

View File

@@ -372,6 +372,16 @@ public final class DynamoDbExtensionSchema {
List.of(),
List.of()),
USED_LINK_DEVICE_TOKENS("used_link_device_tokens_test",
Accounts.KEY_LINK_DEVICE_TOKEN_HASH,
null,
List.of(AttributeDefinition.builder()
.attributeName(Accounts.KEY_LINK_DEVICE_TOKEN_HASH)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(),
List.of()),
USERNAMES("usernames_test",
Accounts.ATTR_USERNAME_HASH,
null,