Create separate key stores for different kinds of pre-keys

This commit is contained in:
Jon Chambers
2023-06-06 17:08:26 -04:00
committed by GitHub
parent cac04146de
commit 2b08742c0a
34 changed files with 1482 additions and 847 deletions

View File

@@ -67,7 +67,7 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@@ -90,7 +90,7 @@ class RegistrationControllerTest {
RegistrationLockVerificationManager.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final Keys keys = mock(Keys.class);
private final KeysManager keysManager = mock(KeysManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
@@ -105,7 +105,7 @@ class RegistrationControllerTest {
.addResource(
new RegistrationController(accountsManager,
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
registrationLockVerificationManager, keys, rateLimiters))
registrationLockVerificationManager, keysManager, rateLimiters))
.build();
@BeforeEach
@@ -669,8 +669,8 @@ class RegistrationControllerTest {
verify(device).setSignedPreKey(expectedAciSignedPreKey);
verify(device).setPhoneNumberIdentitySignedPreKey(expectedPniSignedPreKey);
verify(keys).storePqLastResort(accountIdentifier, Map.of(Device.MASTER_ID, expectedAciPqLastResortPreKey));
verify(keys).storePqLastResort(phoneNumberIdentifier, Map.of(Device.MASTER_ID, expectedPniPqLastResortPreKey));
verify(keysManager).storePqLastResort(accountIdentifier, Map.of(Device.MASTER_ID, expectedAciPqLastResortPreKey));
verify(keysManager).storePqLastResort(phoneNumberIdentifier, Map.of(Device.MASTER_ID, expectedPniPqLastResortPreKey));
expectedApnsToken.ifPresentOrElse(expectedToken -> verify(device).setApnId(expectedToken),
() -> verify(device, never()).setApnId(any()));

View File

@@ -101,7 +101,9 @@ class AccountsManagerChangeNumberIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager, deletedAccounts, mock(Keys.class),
accountLockManager,
deletedAccounts,
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@@ -112,7 +112,9 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.builder().stringCommands(commands).build(),
accountLockManager, deletedAccounts, mock(Keys.class),
accountLockManager,
deletedAccounts,
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@@ -71,7 +71,7 @@ class AccountsManagerTest {
private Accounts accounts;
private DeletedAccounts deletedAccounts;
private Keys keys;
private KeysManager keysManager;
private MessagesManager messagesManager;
private ProfilesManager profilesManager;
private ClientPresenceManager clientPresenceManager;
@@ -94,7 +94,7 @@ class AccountsManagerTest {
void setup() throws InterruptedException {
accounts = mock(Accounts.class);
deletedAccounts = mock(DeletedAccounts.class);
keys = mock(Keys.class);
keysManager = mock(KeysManager.class);
messagesManager = mock(MessagesManager.class);
profilesManager = mock(ProfilesManager.class);
clientPresenceManager = mock(ClientPresenceManager.class);
@@ -157,7 +157,7 @@ class AccountsManagerTest {
RedisClusterHelper.builder().stringCommands(commands).build(),
accountLockManager,
deletedAccounts,
keys,
keysManager,
messagesManager,
profilesManager,
mock(StoredVerificationCodeManager.class),
@@ -542,7 +542,7 @@ class AccountsManagerTest {
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())));
verifyNoInteractions(keys);
verifyNoInteractions(keysManager);
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
}
@@ -565,8 +565,8 @@ class AccountsManagerTest {
verify(accounts)
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())));
verify(keys).delete(existingUuid);
verify(keys).delete(phoneNumberIdentifiersByE164.get(e164));
verify(keysManager).delete(existingUuid);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164));
verify(messagesManager).clear(existingUuid);
verify(profilesManager).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
@@ -585,7 +585,7 @@ class AccountsManagerTest {
verify(accounts).create(
argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())));
verifyNoInteractions(keys);
verifyNoInteractions(keysManager);
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
}
@@ -646,8 +646,8 @@ class AccountsManagerTest {
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
verify(keys).delete(originalPni);
verify(keys).delete(phoneNumberIdentifiersByE164.get(targetNumber));
verify(keysManager).delete(originalPni);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(targetNumber));
}
@Test
@@ -659,7 +659,7 @@ class AccountsManagerTest {
assertEquals(number, account.getNumber());
verify(deletedAccounts, never()).put(any(), any());
verify(keys, never()).delete(any());
verify(keysManager, never()).delete(any());
}
@Test
@@ -674,7 +674,7 @@ class AccountsManagerTest {
verify(accounts, never()).update(any());
verifyNoInteractions(deletedAccounts);
verifyNoInteractions(keys);
verifyNoInteractions(keysManager);
}
@Test
@@ -697,11 +697,11 @@ class AccountsManagerTest {
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys).delete(originalPni);
verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verifyNoMoreInteractions(keys);
verify(keysManager).delete(existingAccountUuid);
verify(keysManager).delete(originalPni);
verify(keysManager, atLeastOnce()).delete(targetPni);
verify(keysManager).delete(newPni);
verifyNoMoreInteractions(keysManager);
}
@Test
@@ -723,7 +723,7 @@ class AccountsManagerTest {
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keys.getPqEnabledDevices(uuid)).thenReturn(List.of(1L));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(List.of(1L));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]);
@@ -735,13 +735,13 @@ class AccountsManagerTest {
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verify(keys).delete(originalPni);
verify(keys).getPqEnabledDevices(uuid);
verify(keys).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verifyNoMoreInteractions(keys);
verify(keysManager).delete(existingAccountUuid);
verify(keysManager, atLeastOnce()).delete(targetPni);
verify(keysManager).delete(newPni);
verify(keysManager).delete(originalPni);
verify(keysManager).getPqEnabledDevices(uuid);
verify(keysManager).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verifyNoMoreInteractions(keysManager);
}
@Test
@@ -792,7 +792,7 @@ class AccountsManagerTest {
verify(accounts).update(any());
verifyNoInteractions(deletedAccounts);
verify(keys).delete(oldPni);
verify(keysManager).delete(oldPni);
}
@Test
@@ -813,7 +813,7 @@ class AccountsManagerTest {
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keys.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
@@ -839,10 +839,10 @@ class AccountsManagerTest {
verify(accounts).update(any());
verifyNoInteractions(deletedAccounts);
verify(keys).delete(oldPni);
verify(keysManager).delete(oldPni);
// only the pq key for the already-pq-enabled device should be saved
verify(keys).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
}
@Test

View File

@@ -116,7 +116,7 @@ class AccountsManagerUsernameIntegrationTest {
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
deletedAccounts,
mock(Keys.class),
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@@ -156,7 +156,7 @@ class AccountsTest {
mock(FaultTolerantRedisCluster.class),
mock(AccountLockManager.class),
mock(DeletedAccounts.class),
mock(Keys.class),
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),

View File

@@ -88,44 +88,44 @@ public final class DynamoDbExtensionSchema {
List.of(), List.of()),
EC_KEYS("keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
SingleUsePreKeyStore.KEY_ACCOUNT_UUID,
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeName(SingleUsePreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeName(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_KEYS("pq_keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
SingleUsePreKeyStore.KEY_ACCOUNT_UUID,
SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeName(SingleUsePreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeName(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_LAST_RESORT_KEYS("pq_last_resort_keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
REPEATED_USE_SIGNED_PRE_KEYS("repeated_use_signed_pre_keys_test",
RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID,
RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeName(RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.attributeName(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID)
.attributeType(ScalarAttributeType.N)
.build()),
List.of(), List.of()),

View File

@@ -0,0 +1,257 @@
/*
* Copyright 2021-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class KeysManagerTest {
private KeysManager keysManager;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_SIGNED_PRE_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
@BeforeEach
void setup() {
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
@Test
void testStore() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
"Initial last-resort pre-key for an account should be missing");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)),
generateTestSignedPreKey(1002));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
}
@Test
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = generateTestPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<PreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID);
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testTakePQ() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = generateTestSignedPreKey(1);
final SignedPreKey preKey2 = generateTestSignedPreKey(2);
final SignedPreKey preKeyLast = generateTestSignedPreKey(1001);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testGetCount() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testDeleteByAccount() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keysManager.delete(ACCOUNT_UUID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testDeleteByAccountAndDevice() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keysManager.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testStorePqLastResort() {
assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair)));
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair)));
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
}
@Test
void testGetPqEnabledDevices() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null);
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), KeysHelper.signedKEMPreKey(4, identityKeyPair));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null);
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID)));
}
private static PreKey generateTestPreKey(final long keyId) {
final byte[] key = new byte[32];
new SecureRandom().nextBytes(key);
return new PreKey(keyId, key);
}
private static SignedPreKey generateTestSignedPreKey(final long keyId) {
final byte[] key = new byte[32];
final byte[] signature = new byte[32];
final SecureRandom secureRandom = new SecureRandom();
secureRandom.nextBytes(key);
secureRandom.nextBytes(signature);
return new SignedPreKey(keyId, key, signature);
}
}

View File

@@ -1,314 +0,0 @@
/*
* Copyright 2021-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.Select;
import static org.junit.jupiter.api.Assertions.*;
class KeysTest {
private Keys keys;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PQ_LAST_RESORT_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
@BeforeEach
void setup() {
keys = new Keys(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.PQ_LAST_RESORT_KEYS.tableName());
}
@Test
void testStore() {
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
"Initial last-resort pre-key for an account should be missing");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect");
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)),
generateTestSignedPreKey(1002));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
}
@Test
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = generateTestPreKey(1);
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<PreKey> takenKey = keys.takeEC(ACCOUNT_UUID, DEVICE_ID);
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testTakePQ() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = generateTestSignedPreKey(1);
final SignedPreKey preKey2 = generateTestSignedPreKey(2);
final SignedPreKey preKeyLast = generateTestSignedPreKey(1001);
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
assertEquals(Optional.of(preKey1), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey2), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testGetCount() {
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testDeleteByAccount() {
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID);
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testDeleteByAccountAndDevice() {
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testStorePqLastResort() {
assertEquals(0, getLastResortCount(ACCOUNT_UUID));
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair)));
assertEquals(2, getLastResortCount(ACCOUNT_UUID));
assertEquals(1L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId());
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId());
assertFalse(keys.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair)));
assertEquals(3, getLastResortCount(ACCOUNT_UUID), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keys.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
}
private int getLastResortCount(UUID uuid) {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(Tables.PQ_LAST_RESORT_KEYS.tableName())
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", Keys.KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(uuid)))
.select(Select.COUNT)
.build();
QueryResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().query(queryRequest);
return response.count();
}
@Test
void testGetPqEnabledDevices() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null);
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair));
keys.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), KeysHelper.signedKEMPreKey(4, identityKeyPair));
keys.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null);
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keys.getPqEnabledDevices(ACCOUNT_UUID)));
}
@Test
void testSortKeyPrefix() {
AttributeValue got = Keys.getSortKeyPrefix(123);
assertArrayEquals(new byte[]{0, 0, 0, 0, 0, 0, 0, 123}, got.b().asByteArray());
}
@ParameterizedTest
@MethodSource
void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) {
assertArrayEquals(expectedByteArray, Keys.extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArray() {
final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc=");
return Stream.of(
Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key)
);
}
@ParameterizedTest
@MethodSource
void extractByteArrayIllegalArgument(final AttributeValue attributeValue) {
assertThrows(IllegalArgumentException.class, () -> Keys.extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArrayIllegalArgument() {
return Stream.of(
Arguments.of(AttributeValue.fromN("12")),
Arguments.of(AttributeValue.fromS("")),
Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎"))
);
}
private static PreKey generateTestPreKey(final long keyId) {
final byte[] key = new byte[32];
new SecureRandom().nextBytes(key);
return new PreKey(keyId, key);
}
private static SignedPreKey generateTestSignedPreKey(final long keyId) {
final byte[] key = new byte[32];
final byte[] signature = new byte[32];
final SecureRandom secureRandom = new SecureRandom();
secureRandom.nextBytes(key);
secureRandom.nextBytes(signature);
return new SignedPreKey(keyId, key, signature);
}
}

View File

@@ -0,0 +1,149 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.reactivestreams.Subscriber;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class RepeatedUseSignedPreKeyStoreTest {
private RepeatedUseSignedPreKeyStore keys;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS);
@BeforeEach
void setUp() {
keys = new RepeatedUseSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
@Test
void storeFind() {
assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join());
{
final UUID identifier = UUID.randomUUID();
final long deviceId = 1;
final SignedPreKey signedPreKey = generateSignedPreKey();
assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join());
assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join());
}
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
assertDoesNotThrow(() -> keys.store(identifier, signedPreKeys).join());
assertEquals(Optional.of(signedPreKeys.get(1L)), keys.find(identifier, 1).join());
assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join());
}
}
@Test
void delete() {
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join());
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, 1).join();
assertEquals(Optional.empty(), keys.find(identifier, 1).join());
assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join());
}
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier).join();
assertEquals(Optional.empty(), keys.find(identifier, 1).join());
assertEquals(Optional.empty(), keys.find(identifier, 2).join());
}
}
@Test
void deleteWithError() {
final DynamoDbAsyncClient mockClient = mock(DynamoDbAsyncClient.class);
final QueryPublisher queryPublisher = mock(QueryPublisher.class);
final SdkPublisher<Map<String, AttributeValue>> itemPublisher = new SdkPublisher<Map<String, AttributeValue>>() {
final Flux<Map<String, AttributeValue>> items = Flux.just(
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(1)),
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(2)));
@Override
public void subscribe(final Subscriber<? super Map<String, AttributeValue>> subscriber) {
items.subscribe(subscriber);
}
};
when(queryPublisher.items()).thenReturn(itemPublisher);
when(mockClient.queryPaginator(any(QueryRequest.class))).thenReturn(queryPublisher);
final Exception deleteItemException = new IllegalArgumentException("OH NO");
when(mockClient.deleteItem(any(DeleteItemRequest.class)))
.thenReturn(CompletableFuture.completedFuture(DeleteItemResponse.builder().build()))
.thenReturn(CompletableFuture.failedFuture(deleteItemException));
final RepeatedUseSignedPreKeyStore keyStore = new RepeatedUseSignedPreKeyStore(mockClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
final CompletionException completionException =
assertThrows(CompletionException.class, () -> keyStore.delete(UUID.randomUUID()).join());
assertEquals(deleteItemException, completionException.getCause());
}
private static SignedPreKey generateSignedPreKey() {
return KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR);
}
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.entities.PreKey;
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<PreKey> {
private SingleUseECPreKeyStore preKeyStore;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.EC_KEYS);
@BeforeEach
void setUp() {
preKeyStore = new SingleUseECPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName());
}
@Override
protected SingleUsePreKeyStore<PreKey> getPreKeyStore() {
return preKeyStore;
}
@Override
protected PreKey generatePreKey(final long keyId) {
return new PreKey(keyId, Curve.generateKeyPair().getPublicKey().serialize());
}
}

View File

@@ -0,0 +1,39 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<SignedPreKey> {
private SingleUseKEMPreKeyStore preKeyStore;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.PQ_KEYS);
@BeforeEach
void setUp() {
preKeyStore = new SingleUseKEMPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName());
}
@Override
protected SingleUsePreKeyStore<SignedPreKey> getPreKeyStore() {
return preKeyStore;
}
@Override
protected SignedPreKey generatePreKey(final long keyId) {
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
}
}

View File

@@ -0,0 +1,155 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.PreKey;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
abstract class SingleUsePreKeyStoreTest<K extends PreKey> {
private static final int KEY_COUNT = 100;
protected abstract SingleUsePreKeyStore<K> getPreKeyStore();
protected abstract K generatePreKey(final long keyId);
@Test
void storeTake() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join());
assertEquals(Optional.of(preKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join());
assertEquals(Optional.of(preKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join());
}
@Test
void getCount() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join());
}
@Test
void deleteSingleDevice() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join();
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId + 1).join());
}
@Test
void deleteAllDevices() {
final SingleUsePreKeyStore<K> preKeyStore = getPreKeyStore();
final UUID accountIdentifier = UUID.randomUUID();
final long deviceId = 1;
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
final List<K> preKeys = new ArrayList<>(KEY_COUNT);
for (int i = 0; i < KEY_COUNT; i++) {
preKeys.add(generatePreKey(i));
}
preKeyStore.store(accountIdentifier, deviceId, preKeys).join();
preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join();
assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join());
}
@ParameterizedTest
@MethodSource
void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) {
assertArrayEquals(expectedByteArray, getPreKeyStore().extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArray() {
final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc=");
return Stream.of(
Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key)
);
}
@ParameterizedTest
@MethodSource
void extractByteArrayIllegalArgument(final AttributeValue attributeValue) {
assertThrows(IllegalArgumentException.class, () -> getPreKeyStore().extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArrayIllegalArgument() {
return Stream.of(
Arguments.of(AttributeValue.fromN("12")),
Arguments.of(AttributeValue.fromS("")),
Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎"))
);
}
}

View File

@@ -23,7 +23,6 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -65,7 +64,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
@@ -82,7 +81,7 @@ class DeviceControllerTest {
public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices,
AccountsManager accounts,
MessagesManager messages,
Keys keys,
KeysManager keys,
RateLimiters rateLimiters,
Map<String, Integer> deviceConfiguration) {
super(pendingDevices, accounts, messages, keys, rateLimiters, deviceConfiguration);
@@ -97,7 +96,7 @@ class DeviceControllerTest {
private static StoredVerificationCodeManager pendingDevicesManager = mock(StoredVerificationCodeManager.class);
private static AccountsManager accountsManager = mock(AccountsManager.class);
private static MessagesManager messagesManager = mock(MessagesManager.class);
private static Keys keys = mock(Keys.class);
private static KeysManager keysManager = mock(KeysManager.class);
private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class);
private static Account account = mock(Account.class);
@@ -117,7 +116,7 @@ class DeviceControllerTest {
.addResource(new DumbVerificationDeviceController(pendingDevicesManager,
accountsManager,
messagesManager,
keys,
keysManager,
rateLimiters,
deviceConfiguration))
.build();
@@ -161,7 +160,7 @@ class DeviceControllerTest {
pendingDevicesManager,
accountsManager,
messagesManager,
keys,
keysManager,
rateLimiters,
rateLimiter,
account,
@@ -314,8 +313,8 @@ class DeviceControllerTest {
verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER);
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID);
verify(keys).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keys).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
}
private static Stream<Arguments> linkDeviceAtomic() {
@@ -822,7 +821,7 @@ class DeviceControllerTest {
verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId);
verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId);
verify(keys).delete(AuthHelper.VALID_UUID, deviceId);
verify(keysManager).delete(AuthHelper.VALID_UUID, deviceId);
}
}

View File

@@ -58,7 +58,7 @@ import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@@ -107,7 +107,7 @@ class KeysControllerTest {
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR);
private final static Keys KEYS = mock(Keys.class );
private final static KeysManager KEYS = mock(KeysManager.class );
private final static AccountsManager accounts = mock(AccountsManager.class );
private final static Account existsAccount = mock(Account.class );