Make KeysManager storage/retrieval operations asynchronous

This commit is contained in:
Jon Chambers
2023-06-26 11:17:02 -04:00
committed by Jon Chambers
parent 5847300290
commit f709b00be3
10 changed files with 204 additions and 165 deletions

View File

@@ -645,6 +645,9 @@ class RegistrationControllerTest {
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final UUID accountIdentifier = UUID.randomUUID();
final UUID phoneNumberIdentifier = UUID.randomUUID();
final Device device = mock(Device.class);
@@ -664,6 +667,8 @@ class RegistrationControllerTest {
return invocation.getArgument(0);
});
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()

View File

@@ -33,6 +33,7 @@ import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -732,7 +733,7 @@ class AccountsManagerTest {
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(List.of(1L));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]);
@@ -783,6 +784,8 @@ class AccountsManagerTest {
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
when(keysManager.getPqEnabledDevices(any())).thenReturn(CompletableFuture.completedFuture(Collections.emptyList()));
final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds);
// non-PNI stuff should not change
@@ -825,7 +828,7 @@ class AccountsManagerTest {
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));

View File

@@ -66,204 +66,208 @@ class KeysManagerTest {
@Test
void testStore() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero");
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero");
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent(),
"Initial last-resort pre-key for an account should be missing");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Repeatedly storing same key should have no effect");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001)).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId());
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002)).join();
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId(),
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
}
@Test
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join());
final ECPreKey preKey = generateTestPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<ECPreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)), null, null, null).join();
final Optional<ECPreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join();
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@Test
void testTakePQ() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join());
final KEMSignedPreKey preKey1 = generateTestKEMSignedPreKey(1);
final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2);
final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast).join();
assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@Test
void testGetCount() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@Test
void testDeleteByAccount() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6));
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10));
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10))
.join();
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
keysManager.delete(ACCOUNT_UUID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
}
@Test
void testDeleteByAccountAndDevice() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1),generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6));
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10));
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10))
.join();
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
keysManager.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
}
@Test
void testStorePqLastResort() {
assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair)));
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent());
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).join().isPresent());
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair)));
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().keyId(), "storing new last-resort keys should overwrite old ones");
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))).join();
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
}
@Test
void testGetPqEnabledDevices() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null);
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null).join();
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID)));
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join()));
}
@Test
@@ -273,14 +277,15 @@ class KeysManagerTest {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1)),
List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)),
KeysHelper.signedECPreKey(3, identityKeyPair),
KeysHelper.signedKEMPreKey(4, identityKeyPair));
List.of(generateTestPreKey(1)),
List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)),
KeysHelper.signedECPreKey(3, identityKeyPair),
KeysHelper.signedKEMPreKey(4, identityKeyPair))
.join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
}

View File

@@ -27,6 +27,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import javax.ws.rs.Path;
import javax.ws.rs.client.Entity;
@@ -282,6 +283,9 @@ class DeviceControllerTest {
when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(fetchesMessages, 1234, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));

View File

@@ -222,15 +222,16 @@ class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.store(any(), anyLong(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI));
when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY_PNI));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI)));
when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI)));
when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY);
@@ -334,7 +335,7 @@ class KeysControllerTest {
@Test
void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty());
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
@@ -520,10 +521,10 @@ class KeysControllerTest {
@Test
void validMultiRequestTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2)));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@@ -575,13 +576,15 @@ class KeysControllerTest {
@Test
void validMultiRequestPqTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3));
when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty());
when(KEYS.takeEC(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takePQ(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))