Use strongly-typed pre-keys

This commit is contained in:
Jon Chambers
2023-06-09 10:08:49 -04:00
committed by GitHub
parent b27334b0ff
commit 17aa5d8e74
47 changed files with 805 additions and 719 deletions

View File

@@ -74,6 +74,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
@@ -84,7 +85,6 @@ import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
@@ -194,7 +194,7 @@ class MessageControllerTest {
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
}
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
final Device device = new Device();
device.setId(id);
device.setRegistrationId(registrationId);

View File

@@ -55,10 +55,11 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
@@ -418,10 +419,10 @@ class RegistrationControllerTest {
static Stream<Arguments> atomicAccountCreationConflictingChannel() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
{
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@@ -507,10 +508,10 @@ class RegistrationControllerTest {
static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
{
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@@ -620,10 +621,10 @@ class RegistrationControllerTest {
void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest,
final IdentityKey expectedAciIdentityKey,
final IdentityKey expectedPniIdentityKey,
final SignedPreKey expectedAciSignedPreKey,
final SignedPreKey expectedPniSignedPreKey,
final SignedPreKey expectedAciPqLastResortPreKey,
final SignedPreKey expectedPniPqLastResortPreKey,
final ECSignedPreKey expectedAciSignedPreKey,
final ECSignedPreKey expectedPniSignedPreKey,
final KEMSignedPreKey expectedAciPqLastResortPreKey,
final KEMSignedPreKey expectedPniPqLastResortPreKey,
final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken) throws InterruptedException {
@@ -686,10 +687,10 @@ class RegistrationControllerTest {
private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
{
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();

View File

@@ -29,7 +29,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@@ -149,17 +149,17 @@ class AccountsManagerChangeNumberIntegrationTest {
final String secondNumber = "+18005552222";
final int rotatedPniRegistrationId = 17;
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final SignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair);
final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair);
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities());
final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>());
account.getMasterDevice().orElseThrow().setSignedPreKey(new SignedPreKey());
account.getMasterDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey);
final Map<Long, ECSignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds);

View File

@@ -55,7 +55,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
@@ -673,9 +674,10 @@ class AccountsManagerTest {
final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber(
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)),
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of(1L, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any());
@@ -719,10 +721,10 @@ class AccountsManagerTest {
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> newSignedKeys = Map.of(
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, SignedPreKey> newSignedPqKeys = Map.of(
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
@@ -768,14 +770,14 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
Map<Long, SignedPreKey> newSignedKeys = Map.of(
Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@@ -810,10 +812,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> newSignedKeys = Map.of(
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, SignedPreKey> newSignedPqKeys = Map.of(
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
@@ -823,7 +825,7 @@ class AccountsManagerTest {
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());

View File

@@ -1023,9 +1023,7 @@ class AccountsTest {
assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId());
assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId());
assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen());
assertThat(resultDevice.getSignedPreKey().getPublicKey()).isEqualTo(expectingDevice.getSignedPreKey().getPublicKey());
assertThat(resultDevice.getSignedPreKey().getKeyId()).isEqualTo(expectingDevice.getSignedPreKey().getKeyId());
assertThat(resultDevice.getSignedPreKey().getSignature()).isEqualTo(expectingDevice.getSignedPreKey().getSignature());
assertThat(resultDevice.getSignedPreKey()).isEqualTo(expectingDevice.getSignedPreKey());
assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages());
assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent());
assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName());

View File

@@ -28,11 +28,15 @@ import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
public class ChangeNumberManagerTest {
private AccountsManager accountsManager;
@@ -106,8 +110,9 @@ public class ChangeNumberManagerTest {
void changeNumberSetPrimaryDevicePrekey() throws Exception {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
var prekeys = Map.of(1L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
@@ -133,8 +138,9 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
@@ -176,9 +182,10 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
@@ -218,9 +225,10 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
@@ -258,8 +266,9 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
@@ -297,9 +306,10 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
@@ -344,7 +354,10 @@ public class ChangeNumberManagerTest {
new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo"));
final Map<Long, SignedPreKey> preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class,
@@ -374,7 +387,10 @@ public class ChangeNumberManagerTest {
new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo"));
final Map<Long, SignedPreKey> preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey());
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class,

View File

@@ -13,14 +13,14 @@ import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
class DeviceTest {
@ParameterizedTest
@MethodSource
void testIsEnabled(final boolean master, final boolean fetchesMessages, final String apnId, final String gcmId,
final SignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) {
final ECSignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) {
final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis();
@@ -41,36 +41,36 @@ class DeviceTest {
// master fetchesMessages apnId gcmId signedPreKey lastSeen expectEnabled
Arguments.of(true, false, null, null, null, Duration.ofDays(60), false),
Arguments.of(true, false, null, null, null, Duration.ofDays(1), false),
Arguments.of(true, false, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(true, false, null, null, mock(SignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(60), false),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(1), false),
Arguments.of(true, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(60), false),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(1), false),
Arguments.of(true, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, true, null, null, null, Duration.ofDays(60), false),
Arguments.of(true, true, null, null, null, Duration.ofDays(1), false),
Arguments.of(true, true, null, null, mock(SignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, true, null, null, mock(SignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, false, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, false, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, null, mock(SignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(1), false),
Arguments.of(false, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, true, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, true, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, true, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, true, null, null, mock(SignedPreKey.class), Duration.ofDays(1), true)
Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true)
);
}
}

View File

@@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.security.SecureRandom;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -21,8 +20,9 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@@ -37,6 +37,8 @@ class KeysManagerTest {
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@BeforeEach
void setup() {
keysManager = new KeysManager(
@@ -62,17 +64,17 @@ class KeysManagerTest {
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestKEMSignedPreKey(1001));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
@@ -80,7 +82,7 @@ class KeysManagerTest {
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
@@ -88,13 +90,12 @@ class KeysManagerTest {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)),
generateTestSignedPreKey(1002));
List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(1002));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(),
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
}
@@ -102,10 +103,10 @@ class KeysManagerTest {
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = generateTestPreKey(1);
final ECPreKey preKey = generateTestPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<PreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID);
final Optional<ECPreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID);
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@@ -114,9 +115,9 @@ class KeysManagerTest {
void testTakePQ() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = generateTestSignedPreKey(1);
final SignedPreKey preKey2 = generateTestSignedPreKey(2);
final SignedPreKey preKeyLast = generateTestSignedPreKey(1001);
final KEMSignedPreKey preKey1 = generateTestKEMSignedPreKey(1);
final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2);
final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
@@ -138,7 +139,7 @@ class KeysManagerTest {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@@ -147,13 +148,11 @@ class KeysManagerTest {
void testDeleteByAccount() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), generateTestKEMSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
List.of(generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
@@ -176,13 +175,11 @@ class KeysManagerTest {
void testDeleteByAccountAndDevice() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), generateTestKEMSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
List.of(generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
@@ -211,17 +208,17 @@ class KeysManagerTest {
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair)));
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair)));
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().keyId(), "storing new last-resort keys should overwrite old ones");
}
@Test
@@ -237,21 +234,15 @@ class KeysManagerTest {
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID)));
}
private static PreKey generateTestPreKey(final long keyId) {
final byte[] key = new byte[32];
new SecureRandom().nextBytes(key);
return new PreKey(keyId, key);
private static ECPreKey generateTestPreKey(final long keyId) {
return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
}
private static SignedPreKey generateTestSignedPreKey(final long keyId) {
final byte[] key = new byte[32];
final byte[] signature = new byte[32];
private static ECSignedPreKey generateTestECSignedPreKey(final long keyId) {
return KeysHelper.signedECPreKey(keyId, IDENTITY_KEY_PAIR);
}
final SecureRandom secureRandom = new SecureRandom();
secureRandom.nextBytes(key);
secureRandom.nextBytes(signature);
return new SignedPreKey(keyId, key, signature);
private static KEMSignedPreKey generateTestKEMSignedPreKey(final long keyId) {
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
}
}

View File

@@ -0,0 +1,44 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import static org.junit.jupiter.api.Assertions.*;
class RepeatedUseKEMSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTest<KEMSignedPreKey> {
private RepeatedUseKEMSignedPreKeyStore keyStore;
private int currentKeyId = 1;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS);
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@BeforeEach
void setUp() {
keyStore = new RepeatedUseKEMSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
@Override
protected RepeatedUseSignedPreKeyStore<KEMSignedPreKey> getKeyStore() {
return keyStore;
}
@Override
protected KEMSignedPreKey generateSignedPreKey() {
return KeysHelper.signedKEMPreKey(currentKeyId++, IDENTITY_KEY_PAIR);
}
}

View File

@@ -5,59 +5,31 @@
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.reactivestreams.Subscriber;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
class RepeatedUseSignedPreKeyStoreTest {
protected abstract RepeatedUseSignedPreKeyStore<K> getKeyStore();
private RepeatedUseSignedPreKeyStore keys;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS);
@BeforeEach
void setUp() {
keys = new RepeatedUseSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
protected abstract K generateSignedPreKey();
@Test
void storeFind() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join());
{
final UUID identifier = UUID.randomUUID();
final long deviceId = 1;
final SignedPreKey signedPreKey = generateSignedPreKey();
final K signedPreKey = generateSignedPreKey();
assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join());
assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join());
@@ -65,7 +37,7 @@ class RepeatedUseSignedPreKeyStoreTest {
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
@@ -78,11 +50,13 @@ class RepeatedUseSignedPreKeyStoreTest {
@Test
void delete() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join());
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
@@ -96,7 +70,7 @@ class RepeatedUseSignedPreKeyStoreTest {
{
final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of(
final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(),
2L, generateSignedPreKey()
);
@@ -108,42 +82,4 @@ class RepeatedUseSignedPreKeyStoreTest {
assertEquals(Optional.empty(), keys.find(identifier, 2).join());
}
}
@Test
void deleteWithError() {
final DynamoDbAsyncClient mockClient = mock(DynamoDbAsyncClient.class);
final QueryPublisher queryPublisher = mock(QueryPublisher.class);
final SdkPublisher<Map<String, AttributeValue>> itemPublisher = new SdkPublisher<Map<String, AttributeValue>>() {
final Flux<Map<String, AttributeValue>> items = Flux.just(
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(1)),
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(2)));
@Override
public void subscribe(final Subscriber<? super Map<String, AttributeValue>> subscriber) {
items.subscribe(subscriber);
}
};
when(queryPublisher.items()).thenReturn(itemPublisher);
when(mockClient.queryPaginator(any(QueryRequest.class))).thenReturn(queryPublisher);
final Exception deleteItemException = new IllegalArgumentException("OH NO");
when(mockClient.deleteItem(any(DeleteItemRequest.class)))
.thenReturn(CompletableFuture.completedFuture(DeleteItemResponse.builder().build()))
.thenReturn(CompletableFuture.failedFuture(deleteItemException));
final RepeatedUseSignedPreKeyStore keyStore = new RepeatedUseSignedPreKeyStore(mockClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
final CompletionException completionException =
assertThrows(CompletionException.class, () -> keyStore.delete(UUID.randomUUID()).join());
assertEquals(deleteItemException, completionException.getCause());
}
private static SignedPreKey generateSignedPreKey() {
return KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR);
}
}

View File

@@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<PreKey> {
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<ECPreKey> {
private SingleUseECPreKeyStore preKeyStore;
@@ -24,12 +24,12 @@ class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<PreKey> {
}
@Override
protected SingleUsePreKeyStore<PreKey> getPreKeyStore() {
protected SingleUsePreKeyStore<ECPreKey> getPreKeyStore() {
return preKeyStore;
}
@Override
protected PreKey generatePreKey(final long keyId) {
return new PreKey(keyId, Curve.generateKeyPair().getPublicKey().serialize());
protected ECPreKey generatePreKey(final long keyId) {
return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
}
}

View File

@@ -9,10 +9,10 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<SignedPreKey> {
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<KEMSignedPreKey> {
private SingleUseKEMPreKeyStore preKeyStore;
@@ -28,12 +28,12 @@ class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<SignedPreKey>
}
@Override
protected SingleUsePreKeyStore<SignedPreKey> getPreKeyStore() {
protected SingleUsePreKeyStore<KEMSignedPreKey> getPreKeyStore() {
return preKeyStore;
}
@Override
protected SignedPreKey generatePreKey(final long keyId) {
protected KEMSignedPreKey generatePreKey(final long keyId) {
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
}
}

View File

@@ -24,7 +24,7 @@ import org.whispersystems.textsecuregcm.entities.PreKey;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
abstract class SingleUsePreKeyStoreTest<K extends PreKey> {
abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
private static final int KEY_COUNT = 100;

View File

@@ -54,9 +54,10 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
@@ -265,10 +266,10 @@ class DeviceControllerTest {
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@@ -351,10 +352,10 @@ class DeviceControllerTest {
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey;
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@@ -395,10 +396,10 @@ class DeviceControllerTest {
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final Optional<SignedPreKey> aciSignedPreKey,
final Optional<SignedPreKey> pniSignedPreKey,
final Optional<SignedPreKey> aciPqLastResortPreKey,
final Optional<SignedPreKey> pniPqLastResortPreKey) {
final Optional<ECSignedPreKey> aciSignedPreKey,
final Optional<ECSignedPreKey> pniSignedPreKey,
final Optional<KEMSignedPreKey> aciPqLastResortPreKey,
final Optional<KEMSignedPreKey> pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@@ -435,10 +436,10 @@ class DeviceControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Optional<SignedPreKey> aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
final Optional<SignedPreKey> pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Optional<SignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
final Optional<SignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Optional<ECSignedPreKey> aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
final Optional<ECSignedPreKey> pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Optional<KEMSignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
final Optional<KEMSignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
@@ -455,10 +456,10 @@ class DeviceControllerTest {
@MethodSource
void linkDeviceAtomicInvalidSignature(final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final SignedPreKey aciSignedPreKey,
final SignedPreKey pniSignedPreKey,
final SignedPreKey aciPqLastResortPreKey,
final SignedPreKey pniPqLastResortPreKey) {
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@@ -495,25 +496,31 @@ class DeviceControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final SignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
final SignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, signedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, signedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, signedPreKeyWithBadSignature(pniPqLastResortPreKey))
Arguments.of(aciIdentityKey, pniIdentityKey, ecSignedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, ecSignedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, kemSignedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, kemSignedPreKeyWithBadSignature(pniPqLastResortPreKey))
);
}
private static SignedPreKey signedPreKeyWithBadSignature(final SignedPreKey signedPreKey) {
return new SignedPreKey(signedPreKey.getKeyId(),
signedPreKey.getPublicKey(),
private static ECSignedPreKey ecSignedPreKeyWithBadSignature(final ECSignedPreKey signedPreKey) {
return new ECSignedPreKey(signedPreKey.keyId(),
signedPreKey.publicKey(),
"incorrect-signature".getBytes(StandardCharsets.UTF_8));
}
private static KEMSignedPreKey kemSignedPreKeyWithBadSignature(final KEMSignedPreKey signedPreKey) {
return new KEMSignedPreKey(signedPreKey.keyId(),
signedPreKey.publicKey(),
"incorrect-signature".getBytes(StandardCharsets.UTF_8));
}

View File

@@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
@@ -20,6 +21,8 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@@ -47,6 +50,9 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccou
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
@@ -63,6 +69,7 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
@ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest {
@@ -86,27 +93,27 @@ class KeysControllerTest {
private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private final IdentityKey PNI_IDENTITY_KEY = new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey());
private final PreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234);
private final PreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667);
private final PreKey SAMPLE_KEY3 = KeysHelper.ecPreKey(334);
private final PreKey SAMPLE_KEY4 = KeysHelper.ecPreKey(336);
private final ECPreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234);
private final ECPreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667);
private final ECPreKey SAMPLE_KEY3 = KeysHelper.ecPreKey(334);
private final ECPreKey SAMPLE_KEY4 = KeysHelper.ecPreKey(336);
private final PreKey SAMPLE_KEY_PNI = KeysHelper.ecPreKey(7777);
private final ECPreKey SAMPLE_KEY_PNI = KeysHelper.ecPreKey(7777);
private final SignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair());
private final KEMSignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair());
private final KEMSignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair());
private final KEMSignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair());
private final KEMSignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedECPreKey(1111, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedECPreKey(2222, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedECPreKey(3333, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(4444, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedECPreKey(6666, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR);
private final ECSignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedECPreKey(1111, IDENTITY_KEY_PAIR);
private final ECSignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedECPreKey(2222, IDENTITY_KEY_PAIR);
private final ECSignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedECPreKey(3333, IDENTITY_KEY_PAIR);
private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(4444, PNI_IDENTITY_KEY_PAIR);
private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR);
private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedECPreKey(6666, PNI_IDENTITY_KEY_PAIR);
private final ECSignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR);
private final ECSignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR);
private final static KeysManager KEYS = mock(KeysManager.class );
private final static AccountsManager accounts = mock(AccountsManager.class );
@@ -127,6 +134,42 @@ class KeysControllerTest {
private Device sampleDevice;
private record WeaklyTypedPreKey(long keyId,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] publicKey) {
static WeaklyTypedPreKey fromPreKey(final PreKey<?> preKey) {
return new WeaklyTypedPreKey(preKey.keyId(), preKey.serializedPublicKey());
}
}
private record WeaklyTypedSignedPreKey(long keyId,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] publicKey,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] signature) {
static WeaklyTypedSignedPreKey fromSignedPreKey(final SignedPreKey<?> signedPreKey) {
return new WeaklyTypedSignedPreKey(signedPreKey.keyId(), signedPreKey.serializedPublicKey(), signedPreKey.signature());
}
}
private record WeaklyTypedPreKeyState(List<WeaklyTypedPreKey> preKeys,
WeaklyTypedSignedPreKey signedPreKey,
List<WeaklyTypedSignedPreKey> pqPreKeys,
WeaklyTypedSignedPreKey pqLastResortPreKey,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] identityKey) {
}
@BeforeEach
void setup() {
sampleDevice = mock(Device.class);
@@ -228,30 +271,30 @@ class KeysControllerTest {
@Test
void getSignedPreKeyV2() {
SignedPreKey result = resources.getJerseyTest()
ECSignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class);
.get(ECSignedPreKey.class);
assertKeysMatch(VALID_DEVICE_SIGNED_KEY, result);
assertEquals(VALID_DEVICE_SIGNED_KEY, result);
}
@Test
void getPhoneNumberIdentifierSignedPreKeyV2() {
SignedPreKey result = resources.getJerseyTest()
ECSignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.queryParam("identity", "pni")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class);
.get(ECSignedPreKey.class);
assertKeysMatch(VALID_DEVICE_PNI_SIGNED_KEY, result);
assertEquals(VALID_DEVICE_PNI_SIGNED_KEY, result);
}
@Test
void putSignedPreKeyV2() {
SignedPreKey test = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR);
ECSignedPreKey test = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@@ -267,7 +310,7 @@ class KeysControllerTest {
@Test
void putPhoneNumberIdentitySignedPreKeyV2() {
final SignedPreKey replacementKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR);
final ECSignedPreKey replacementKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -285,7 +328,7 @@ class KeysControllerTest {
@Test
void disabledPutSignedPreKeyV2() {
SignedPreKey test = KeysHelper.signedECPreKey(9999, IDENTITY_KEY_PAIR);
ECSignedPreKey test = KeysHelper.signedECPreKey(9999, IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@@ -305,10 +348,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
@@ -316,7 +359,7 @@ class KeysControllerTest {
@Test
void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.<SignedPreKey>empty());
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty());
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
@@ -327,10 +370,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
@@ -348,10 +391,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
@@ -368,10 +411,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
@@ -388,10 +431,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).takePQ(EXISTS_PNI, 1);
@@ -410,10 +453,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
@@ -445,9 +488,9 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
@@ -510,14 +553,14 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
ECPreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
assertKeysMatch(SAMPLE_KEY, preKey);
assertEquals(SAMPLE_KEY, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey);
assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey();
@@ -525,9 +568,9 @@ class KeysControllerTest {
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
assertKeysMatch(SAMPLE_KEY2, preKey);
assertEquals(SAMPLE_KEY2, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey);
assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey();
@@ -535,7 +578,7 @@ class KeysControllerTest {
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
assertKeysMatch(SAMPLE_KEY4, preKey);
assertEquals(SAMPLE_KEY4, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
@@ -554,7 +597,7 @@ class KeysControllerTest {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3));
when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.<SignedPreKey>empty());
when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty());
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@@ -566,16 +609,16 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
SignedPreKey pqPreKey = results.getDevice(1).getPqPreKey();
ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
ECPreKey preKey = results.getDevice(1).getPreKey();
KEMSignedPreKey pqPreKey = results.getDevice(1).getPqPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
assertKeysMatch(SAMPLE_KEY, preKey);
assertKeysMatch(SAMPLE_PQ_KEY, pqPreKey);
assertEquals(SAMPLE_KEY, preKey);
assertEquals(SAMPLE_PQ_KEY, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey);
assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey();
@@ -585,9 +628,9 @@ class KeysControllerTest {
deviceId = results.getDevice(2).getDeviceId();
assertThat(preKey).isNull();
assertKeysMatch(SAMPLE_PQ_KEY2, pqPreKey);
assertEquals(SAMPLE_PQ_KEY2, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey);
assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey();
@@ -596,7 +639,7 @@ class KeysControllerTest {
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
assertKeysMatch(SAMPLE_KEY4, preKey);
assertEquals(SAMPLE_KEY4, preKey);
assertThat(pqPreKey).isNull();
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
@@ -656,9 +699,9 @@ class KeysControllerTest {
@Test
void putKeysTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@@ -672,7 +715,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey);
@@ -684,11 +727,11 @@ class KeysControllerTest {
@Test
void putKeysPqTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final KEMSignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final KEMSignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
@@ -702,8 +745,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
@@ -718,8 +761,9 @@ class KeysControllerTest {
void putKeysStructurallyInvalidSignedECKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair);
final PreKeyState preKeyState = new PreKeyState(identityKey, wrongPreKey, null, null, null);
final KEMSignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair);
final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(null, WeaklyTypedSignedPreKey.fromSignedPreKey(wrongPreKey), null, null, identityKey.serialize());
Response response =
resources.getJerseyTest()
@@ -728,15 +772,16 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
assertThat(response.getStatus()).isEqualTo(400);
}
@Test
void putKeysStructurallyInvalidUnsignedECKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final PreKey wrongPreKey = new PreKey(1, "cluck cluck i'm a parrot".getBytes());
final PreKeyState preKeyState = new PreKeyState(identityKey, null, List.of(wrongPreKey), null, null);
final WeaklyTypedPreKey wrongPreKey = new WeaklyTypedPreKey(1, "cluck cluck i'm a parrot".getBytes());
final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(List.of(wrongPreKey), null, null, null, identityKey.serialize());
Response response =
resources.getJerseyTest()
@@ -745,15 +790,16 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
assertThat(response.getStatus()).isEqualTo(400);
}
@Test
void putKeysStructurallyInvalidPQOneTimeKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair);
final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, List.of(wrongPreKey), null);
final WeaklyTypedSignedPreKey wrongPreKey = WeaklyTypedSignedPreKey.fromSignedPreKey(KeysHelper.signedECPreKey(1, identityKeyPair));
final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(null, null, List.of(wrongPreKey), null, identityKey.serialize());
Response response =
resources.getJerseyTest()
@@ -762,15 +808,16 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
assertThat(response.getStatus()).isEqualTo(400);
}
@Test
void putKeysStructurallyInvalidPQLastResortKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair);
final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, null, wrongPreKey);
final WeaklyTypedSignedPreKey wrongPreKey = WeaklyTypedSignedPreKey.fromSignedPreKey(KeysHelper.signedECPreKey(1, identityKeyPair));
final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(null, null, null, wrongPreKey, identityKey.serialize());
Response response =
resources.getJerseyTest()
@@ -779,14 +826,14 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
assertThat(response.getStatus()).isEqualTo(400);
}
@Test
void putKeysByPhoneNumberIdentifierTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@@ -801,7 +848,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey);
@@ -813,11 +860,11 @@ class KeysControllerTest {
@Test
void putKeysByPhoneNumberIdentifierPqTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final KEMSignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final KEMSignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
@@ -832,8 +879,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
@@ -846,7 +893,7 @@ class KeysControllerTest {
@Test
void putPrekeyWithInvalidSignature() {
final SignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair());
final ECSignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair());
PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, badSignedPreKey, List.of());
Response response =
resources.getJerseyTest()
@@ -861,9 +908,9 @@ class KeysControllerTest {
@Test
void disabledPutKeysTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@@ -877,13 +924,13 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue();
List<ECPreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337);
assertThat(capturedList.get(0).getPublicKey()).isEqualTo(preKey.getPublicKey());
assertThat(capturedList.get(0).keyId()).isEqualTo(31337);
assertThat(capturedList.get(0).publicKey()).isEqualTo(preKey.publicKey());
verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey));
@@ -892,10 +939,10 @@ class KeysControllerTest {
@Test
void putIdentityKeyNonPrimary() {
final PreKey preKey = KeysHelper.ecPreKey(31337);
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR);
final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR);
List<PreKey> preKeys = List.of(preKey);
List<ECPreKey> preKeys = List.of(preKey);
PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, signedPreKey, preKeys);
@@ -908,13 +955,4 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(403);
}
private void assertKeysMatch(PreKey expected, PreKey actual) {
assertThat(actual.getKeyId()).isEqualTo(expected.getKeyId());
assertThat(actual.getPublicKey()).isEqualTo(expected.getPublicKey());
if (expected instanceof final SignedPreKey signedExpected) {
final SignedPreKey signedActual = (SignedPreKey) actual;
assertThat(signedActual.getSignature()).isEqualTo(signedExpected.getSignature());
}
}
}

View File

@@ -12,7 +12,8 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import java.util.Base64;
@@ -22,7 +23,7 @@ class PreKeyTest {
@Test
void serializeToJSONV2() throws Exception {
PreKey preKey = new PreKey(1234, PUBLIC_KEY);
ECPreKey preKey = new ECPreKey(1234, new ECPublicKey(PUBLIC_KEY));
assertThat("PreKeyV2 Serialization works",
asJson(preKey),

View File

@@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.util;
import java.util.Random;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Util;

View File

@@ -7,26 +7,29 @@ package org.whispersystems.textsecuregcm.tests.util;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.kem.KEMKeyPair;
import org.signal.libsignal.protocol.kem.KEMKeyType;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
public final class KeysHelper {
public static PreKey ecPreKey(final long id) {
return new PreKey(id, Curve.generateKeyPair().getPublicKey().serialize());
public static ECPreKey ecPreKey(final long id) {
return new ECPreKey(id, Curve.generateKeyPair().getPublicKey());
}
public static SignedPreKey signedECPreKey(long id, final ECKeyPair identityKeyPair) {
final byte[] pubKey = Curve.generateKeyPair().getPublicKey().serialize();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey);
return new SignedPreKey(id, pubKey, sig);
public static ECSignedPreKey signedECPreKey(long id, final ECKeyPair identityKeyPair) {
final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new ECSignedPreKey(id, pubKey, sig);
}
public static SignedPreKey signedKEMPreKey(long id, final ECKeyPair identityKeyPair) {
final byte[] pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey().serialize();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey);
return new SignedPreKey(id, pubKey, sig);
public static KEMSignedPreKey signedKEMPreKey(long id, final ECKeyPair identityKeyPair) {
final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new KEMSignedPreKey(id, pubKey, sig);
}
}