Remove bulk "set repeated-use signed pre-keys" methods because they were only ever used for single devices

This commit is contained in:
Jon Chambers
2023-12-19 12:21:21 -05:00
committed by Jon Chambers
parent 25c3f55672
commit 057d1f07a8
9 changed files with 46 additions and 108 deletions

View File

@@ -254,11 +254,11 @@ class KeysControllerTest {
when(KEYS.storeKemOneTimePreKeys(any(), anyByte(), any()))
.thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.storePqLastResort(any(), any()))
when(KEYS.storePqLastResort(any(), anyByte(), any()))
.thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
@@ -827,7 +827,7 @@ class KeysControllerTest {
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture());
verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), pqCaptor.capture());
verify(KEYS).storePqLastResort(AuthHelper.VALID_UUID, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey));
verify(KEYS).storePqLastResort(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID, pqLastResortPreKey);
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
@@ -966,7 +966,7 @@ class KeysControllerTest {
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture());
verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), pqCaptor.capture());
verify(KEYS).storePqLastResort(AuthHelper.VALID_PNI, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey));
verify(KEYS).storePqLastResort(AuthHelper.VALID_PNI, SAMPLE_DEVICE_ID, pqLastResortPreKey);
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);

View File

@@ -304,7 +304,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
return CompletableFuture.completedFuture(account);
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final ECKeyPair identityKeyPair = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) {
case ACI -> ACI_IDENTITY_KEY_PAIR;
@@ -326,12 +326,12 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
switch (identityType) {
case IDENTITY_TYPE_ACI -> {
verify(authenticatedDevice).setSignedPreKey(signedPreKey);
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey));
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, signedPreKey);
}
case IDENTITY_TYPE_PNI -> {
verify(authenticatedDevice).setPhoneNumberIdentitySignedPreKey(signedPreKey);
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey));
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID, signedPreKey);
}
}
}
@@ -387,7 +387,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
@ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void setLastResortPreKey(final org.signal.chat.common.IdentityType identityType) {
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final ECKeyPair identityKeyPair = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) {
case ACI -> ACI_IDENTITY_KEY_PAIR;
@@ -412,7 +412,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
case IDENTITY_TYPE_UNSPECIFIED, UNRECOGNIZED -> throw new AssertionError("Bad identity type");
};
verify(keysManager).storePqLastResort(expectedIdentifier, Map.of(AUTHENTICATED_DEVICE_ID, lastResortPreKey));
verify(keysManager).storePqLastResort(expectedIdentifier, AUTHENTICATED_DEVICE_ID, lastResortPreKey);
}
@ParameterizedTest

View File

@@ -1284,7 +1284,7 @@ class AccountsManagerTest {
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID, deviceId3)));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final List<Device> devices = List.of(
DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
@@ -1381,7 +1381,7 @@ class AccountsManagerTest {
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
when(keysManager.getPqEnabledDevices(any())).thenReturn(CompletableFuture.completedFuture(Collections.emptyList()));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds);
@@ -1437,8 +1437,8 @@ class AccountsManagerTest {
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(
CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID)));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
@@ -1500,8 +1500,8 @@ class AccountsManagerTest {
UUID oldPni = account.getPhoneNumberIdentifier();
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of()));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));

View File

@@ -97,7 +97,7 @@ class KeysManagerTest {
final ECSignedPreKey signedPreKey = generateTestECSignedPreKey(1);
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, signedPreKey)).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, DEVICE_ID, signedPreKey).join();
assertEquals(Optional.of(signedPreKey), keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join());
}
@@ -124,7 +124,7 @@ class KeysManagerTest {
final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001);
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(preKey1, preKey2)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(DEVICE_ID, preKeyLast)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, preKeyLast).join();
assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
@@ -146,8 +146,8 @@ class KeysManagerTest {
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, deviceId, generateTestECSignedPreKey(keyId++)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, deviceId, generateTestKEMSignedPreKey(keyId++)).join();
}
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
@@ -174,8 +174,8 @@ class KeysManagerTest {
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, deviceId, generateTestECSignedPreKey(keyId++)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, deviceId, generateTestKEMSignedPreKey(keyId++)).join();
}
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
@@ -207,19 +207,17 @@ class KeysManagerTest {
final byte deviceId2 = 2;
final byte deviceId3 = 3;
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair), (byte) 2,
KeysHelper.signedKEMPreKey(2, identityKeyPair))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, (byte) 2, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent());
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), deviceId3,
KeysHelper.signedKEMPreKey(4, identityKeyPair))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, deviceId3, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join();
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(),
"storing new last-resort keys should overwrite old ones");
@@ -227,19 +225,14 @@ class KeysManagerTest {
"storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().get().keyId(),
"storing new last-resort keys should overwrite old ones");
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of()).join();
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing zero last-resort keys should be a no-op");
}
@Test
void testGetPqEnabledDevices() {
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2)).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), List.of(generateTestKEMSignedPreKey(3))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4)).join();
assertIterableEquals(
Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)),
@@ -250,7 +243,7 @@ class KeysManagerTest {
void testStoreEcSignedPreKeyDisabled() {
when(ecPreKeyMigrationConfiguration.storeEcSignedPreKeys()).thenReturn(false);
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, generateTestECSignedPreKey(1))).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, DEVICE_ID, generateTestECSignedPreKey(1)).join();
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
}

View File

@@ -30,27 +30,12 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), Device.PRIMARY_ID).join());
{
final UUID identifier = UUID.randomUUID();
final byte deviceId = 1;
final K signedPreKey = generateSignedPreKey();
final UUID identifier = UUID.randomUUID();
final byte deviceId = 1;
final K signedPreKey = generateSignedPreKey();
assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join());
assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join());
}
{
final UUID identifier = UUID.randomUUID();
final byte deviceId2 = 2;
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
assertDoesNotThrow(() -> keys.store(identifier, signedPreKeys).join());
assertEquals(Optional.of(signedPreKeys.get(Device.PRIMARY_ID)), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
}
assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join());
assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join());
}
@Test
@@ -75,18 +60,16 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
final UUID identifier = UUID.randomUUID();
final byte deviceId2 = 2;
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
final K retainedPreKey = generateSignedPreKey();
keys.store(identifier, signedPreKeys).join();
keys.store(identifier, Device.PRIMARY_ID, generateSignedPreKey()).join();
keys.store(identifier, deviceId2, retainedPreKey).join();
getDynamoDbClient().transactWriteItems(TransactWriteItemsRequest.builder()
.transactItems(keys.buildTransactWriteItemForDeletion(identifier, Device.PRIMARY_ID))
.build());
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
assertEquals(Optional.of(retainedPreKey), keys.find(identifier, deviceId2).join());
}
}