Update signed pre-keys in transactions

This commit is contained in:
Jon Chambers
2023-12-05 14:20:16 -05:00
committed by GitHub
parent ede9297139
commit df421e0182
8 changed files with 415 additions and 252 deletions

View File

@@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
@@ -35,6 +34,7 @@ import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
@@ -236,9 +236,27 @@ class KeysControllerTest {
when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount));
when(accounts.updateDeviceTransactionallyAsync(any(), anyByte(), any(), any())).thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final byte deviceId = invocation.getArgument(1);
final Consumer<Device> deviceUpdater = invocation.getArgument(2);
deviceUpdater.accept(account.getDevice(deviceId).orElseThrow());
return CompletableFuture.completedFuture(account);
});
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.store(any(), anyByte(), any(), any(), any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.storeEcOneTimePreKeys(any(), anyByte(), any()))
.thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.storeKemOneTimePreKeys(any(), anyByte(), any()))
.thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.storePqLastResort(any(), any()))
.thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));
@@ -301,8 +319,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(AuthHelper.VALID_DEVICE, never()).setPhoneNumberIdentitySignedPreKey(any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any());
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(Device.PRIMARY_ID, test));
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any(), any());
}
@Test
@@ -320,8 +337,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(replacementKey));
verify(AuthHelper.VALID_DEVICE, never()).setSignedPreKey(any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any());
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(Device.PRIMARY_ID, replacementKey));
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any(), any());
}
@Test
@@ -748,13 +764,13 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(),
eq(signedPreKey), isNull());
verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture());
assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
}
@Test
@@ -782,14 +798,15 @@ class KeysControllerTest {
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(),
eq(signedPreKey), eq(pqLastResortPreKey));
verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture());
verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), pqCaptor.capture());
verify(KEYS).storePqLastResort(AuthHelper.VALID_UUID, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
}
@Test
@@ -886,13 +903,12 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), eq(signedPreKey),
isNull());
verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), listCaptor.capture());
assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
}
@Test
@@ -921,14 +937,15 @@ class KeysControllerTest {
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(),
eq(signedPreKey), eq(pqLastResortPreKey));
verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture());
verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), pqCaptor.capture());
verify(KEYS).storePqLastResort(AuthHelper.VALID_PNI, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
}
@Test
@@ -967,8 +984,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(),
eq(signedPreKey), isNull());
verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.DISABLED_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture());
List<ECPreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@@ -976,6 +992,6 @@ class KeysControllerTest {
assertThat(capturedList.get(0).publicKey()).isEqualTo(preKey.publicKey());
verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any());
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.DISABLED_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
}
}

View File

@@ -154,6 +154,7 @@ class AccountsManagerTest {
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null));
when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(accounts.delete(any())).thenReturn(CompletableFuture.completedFuture(null));
doAnswer((Answer<Void>) invocation -> {
@@ -164,7 +165,7 @@ class AccountsManagerTest {
account.setNumber(number, phoneNumberIdentifier);
return null;
}).when(accounts).changeNumber(any(), anyString(), any(), any());
}).when(accounts).changeNumber(any(), anyString(), any(), any(), any());
final SecureStorageClient storageClient = mock(SecureStorageClient.class);
when(storageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null));
@@ -1163,7 +1164,6 @@ class AccountsManagerTest {
verify(keysManager).delete(originalPni);
verify(keysManager, atLeastOnce()).delete(targetPni);
verify(keysManager).delete(newPni);
verify(keysManager).storeEcSignedPreKeys(eq(newPni), any());
verifyNoMoreInteractions(keysManager);
}
@@ -1209,9 +1209,9 @@ class AccountsManagerTest {
verify(keysManager).delete(newPni);
verify(keysManager).delete(originalPni);
verify(keysManager).getPqEnabledDevices(uuid);
verify(keysManager).storeEcSignedPreKeys(newPni, newSignedKeys);
verify(keysManager).storePqLastResort(eq(newPni),
eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID))));
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any());
verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(Device.PRIMARY_ID), any());
verifyNoMoreInteractions(keysManager);
}
@@ -1304,9 +1304,12 @@ class AccountsManagerTest {
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
verify(accounts).update(any());
verify(accounts).updateTransactionallyAsync(any(), any());
verify(keysManager).delete(oldPni);
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any());
verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any());
}
@Test
@@ -1360,14 +1363,13 @@ class AccountsManagerTest {
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
verify(accounts).update(any());
verify(accounts).updateTransactionallyAsync(any(), any());
verify(keysManager).delete(oldPni);
verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys);
// only the pq key for the already-pq-enabled device should be saved
verify(keysManager).storePqLastResort(eq(oldPni),
eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID))));
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any());
verify(keysManager).buildWriteItemForLastResortKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager, never()).buildWriteItemForLastResortKey(eq(oldPni), eq(deviceId2), any());
}
@Test
@@ -1420,11 +1422,12 @@ class AccountsManagerTest {
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
verify(accounts).update(any());
verify(accounts).updateTransactionallyAsync(any(), any());
verify(keysManager).delete(oldPni);
verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys);
verify(keysManager).storePqLastResort(any(), argThat(Map::isEmpty));
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any());
verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any());
}
@Test

View File

@@ -12,6 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -99,7 +100,10 @@ class AccountsTest {
Tables.NUMBERS,
Tables.PNI_ASSIGNMENTS,
Tables.USERNAMES,
Tables.DELETED_ACCOUNTS);
Tables.DELETED_ACCOUNTS,
// This is an unrelated table used to test "tag-along" transactional updates
Tables.CLIENT_RELEASES);
private final TestClock clock = TestClock.pinned(Instant.EPOCH);
private DynamicConfigurationManager<DynamicConfiguration> mockDynamicConfigManager;
@@ -558,6 +562,71 @@ class AccountsTest {
assertThatThrownBy(() -> accounts.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class);
}
@Test
void testUpdateTransactionally() {
final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID());
createAccount(account);
final String deviceName = "device-name";
assertNotEquals(deviceName,
accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().orElseThrow().getName());
assertFalse(DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(Tables.CLIENT_RELEASES.tableName())
.key(Map.of(
ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"),
ClientReleases.ATTR_VERSION, AttributeValues.fromString("test")
))
.build())
.hasItem());
account.getPrimaryDevice().orElseThrow().setName(deviceName);
accounts.updateTransactionallyAsync(account, List.of(TransactWriteItem.builder()
.put(Put.builder()
.tableName(Tables.CLIENT_RELEASES.tableName())
.item(Map.of(
ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"),
ClientReleases.ATTR_VERSION, AttributeValues.fromString("test")
))
.build())
.build())).toCompletableFuture().join();
assertEquals(deviceName,
accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().orElseThrow().getName());
assertTrue(DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(Tables.CLIENT_RELEASES.tableName())
.key(Map.of(
ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"),
ClientReleases.ATTR_VERSION, AttributeValues.fromString("test")
))
.build())
.hasItem());
}
@Test
void testUpdateTransactionallyContestedLock() {
final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID());
createAccount(account);
account.setVersion(account.getVersion() - 1);
final CompletionException completionException = assertThrows(CompletionException.class,
() -> accounts.updateTransactionallyAsync(account, List.of(TransactWriteItem.builder()
.put(Put.builder()
.tableName(Tables.CLIENT_RELEASES.tableName())
.item(Map.of(
ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"),
ClientReleases.ATTR_VERSION, AttributeValues.fromString("test")
))
.build())
.build())).toCompletableFuture().join());
assertTrue(completionException.getCause() instanceof ContestedOptimisticLockException);
}
@Test
void testGetAll() {
final List<Account> expectedAccounts = new ArrayList<>();
@@ -719,7 +788,7 @@ class AccountsTest {
verifyStoredState(originalNumber, account.getUuid(), account.getPhoneNumberIdentifier(), null, retrieved.get(), account);
}
accounts.changeNumber(account, targetNumber, targetPni, maybeDisplacedAccountIdentifier);
accounts.changeNumber(account, targetNumber, targetPni, maybeDisplacedAccountIdentifier, Collections.emptyList());
assertThat(accounts.getByE164(originalNumber)).isEmpty();
assertThat(accounts.getByAccountIdentifier(originalPni)).isEmpty();
@@ -766,7 +835,7 @@ class AccountsTest {
createAccount(account);
createAccount(existingAccount);
assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid())));
assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid()), Collections.emptyList()));
assertPhoneNumberConstraintExists(originalNumber, account.getUuid());
assertPhoneNumberIdentifierConstraintExists(originalPni, account.getUuid());
@@ -802,7 +871,7 @@ class AccountsTest {
Map.of(":uuid", AttributeValues.fromUUID(existingAccountIdentifier)))
.build());
assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, existingPhoneNumberIdentifier, Optional.empty()));
assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, existingPhoneNumberIdentifier, Optional.empty(), Collections.emptyList()));
}
@Test

View File

@@ -20,6 +20,8 @@ import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@@ -65,54 +67,39 @@ class KeysManagerTest {
}
@Test
void testStore() {
void storeEcOneTimePreKeys() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero");
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero");
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent(),
"Initial last-resort pre-key for an account should be missing");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join();
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join();
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Repeatedly storing same key should have no effect");
}
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ prekeys should have no effect on EC prekeys");
@Test
void storeKemOneTimePreKeys() {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Initial pre-key count for an account should be zero");
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001)).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId());
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
}
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Uploading new EC prekeys should have no effect on PQ prekeys");
@Test
void storeEcSignedPreKeys() {
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isEmpty());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
final ECSignedPreKey signedPreKey = generateTestECSignedPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002)).join();
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, signedPreKey)).join();
assertEquals(Optional.of(signedPreKey), keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join());
}
@Test
@@ -121,7 +108,8 @@ class KeysManagerTest {
final ECPreKey preKey = generateTestPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)), null, null, null).join();
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))).join();
final Optional<ECPreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join();
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
@@ -135,7 +123,8 @@ class KeysManagerTest {
final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2);
final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(preKey1, preKey2)).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(DEVICE_ID, preKeyLast)).join();
assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
@@ -150,79 +139,51 @@ class KeysManagerTest {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@Test
void testGetCount() {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null).join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
}
@Test
void testDeleteByAccount() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6))
.join();
int keyId = 1;
final byte deviceId2 = DEVICE_ID + 1;
keysManager.store(ACCOUNT_UUID, deviceId2,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10))
.join();
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join();
}
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent());
}
keysManager.delete(ACCOUNT_UUID).join();
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent());
}
}
@Test
void testDeleteByAccountAndDevice() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)),
generateTestECSignedPreKey(5),
generateTestKEMSignedPreKey(6))
.join();
int keyId = 1;
final byte deviceId2 = DEVICE_ID + 1;
keysManager.store(ACCOUNT_UUID, deviceId2,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
generateTestKEMSignedPreKey(10))
.join();
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join();
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join();
}
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) {
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent());
}
keysManager.delete(ACCOUNT_UUID, DEVICE_ID).join();
@@ -230,10 +191,11 @@ class KeysManagerTest {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join().isPresent());
}
@Test
@@ -272,15 +234,13 @@ class KeysManagerTest {
@Test
void testGetPqEnabledDevices() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2))).join();
keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), List.of(generateTestKEMSignedPreKey(3))).join();
keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4))).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), null, null, null,
KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), null,
List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair))
.join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 3), null, null, null, null).join();
assertIterableEquals(
Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join()));
@@ -290,21 +250,18 @@ class KeysManagerTest {
void testStoreEcSignedPreKeyDisabled() {
when(ecPreKeyMigrationConfiguration.storeEcSignedPreKeys()).thenReturn(false);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1)),
List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)),
KeysHelper.signedECPreKey(3, identityKeyPair),
KeysHelper.signedKEMPreKey(4, identityKeyPair))
.join();
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, generateTestECSignedPreKey(1))).join();
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void buildWriteItemForEcSignedPreKey(final boolean enableSignedPreKeyWrite) {
when(ecPreKeyMigrationConfiguration.storeEcSignedPreKeys()).thenReturn(enableSignedPreKeyWrite);
assertEquals(enableSignedPreKeyWrite,
keysManager.buildWriteItemForEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID, generateTestECSignedPreKey(1)).isPresent());
}
private static ECPreKey generateTestPreKey(final long keyId) {
return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
}