diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 00eb7a4e3..b77f45f1b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -34,7 +34,6 @@ import jakarta.ws.rs.QueryParam; import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; -import java.io.IOException; import java.nio.ByteBuffer; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -57,7 +56,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.GroupSendTokenHeader; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.entities.CheckKeysRequest; -import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount; @@ -92,7 +90,6 @@ public class KeysController { private final ServerSecretParams serverSecretParams; private final Clock clock; - private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys"); private static final String STORE_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "storeKeys"); private static final String STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME = MetricsUtil.name(KeysController.class, "storeKeyBundleSize"); @@ -395,51 +392,14 @@ public class KeysController { final List devices = parseDeviceId(deviceId, target); - final List responseItems = Flux.fromIterable(devices) - .flatMap(device -> Mono.zip( - Mono.just(device), - Mono.fromFuture(keysManager.takeEC(targetIdentifier.uuid(), device.getId())), - Mono.fromFuture(keysManager.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())), - Mono.fromFuture(keysManager.takePQ(targetIdentifier.uuid(), device.getId())))) - .flatMap(deviceAndPreKeys -> { - final Device device = deviceAndPreKeys.getT1(); - final KEMSignedPreKey pqPreKey = deviceAndPreKeys.getT4().orElse(null); - final ECPreKey unsignedEcPreKey = deviceAndPreKeys.getT2().orElse(null); - final ECSignedPreKey signedEcPreKey = deviceAndPreKeys.getT3().orElse(null); - final int registrationId = device.getRegistrationId(targetIdentifier.identityType()); - - Metrics.counter(GET_KEYS_COUNTER_NAME, Tags.of( - UserAgentTagUtil.getPlatformTag(userAgent), - Tag.of(IDENTITY_TYPE_TAG_NAME, targetIdentifier.identityType().name()), - Tag.of("oneTimeEcKeyAvailable", String.valueOf(unsignedEcPreKey != null)), - Tag.of("signedEcKeyAvailable", String.valueOf(signedEcPreKey != null)), - Tag.of("pqKeyAvailable", String.valueOf(pqPreKey != null)))) - .increment(); - - if (pqPreKey == null) { - // The PQ prekey should never be null. This should only happen if the account or device has been - // removed. - return Mono.fromCompletionStage(() -> accounts.getByServiceIdentifierAsync(targetIdentifier)) - .flatMap(maybeAccount -> maybeAccount - .flatMap(rereadAccount -> rereadAccount.getDevice(device.getId())) - .filter(rereadDevice -> - registrationId == rereadDevice.getRegistrationId(targetIdentifier.identityType())) - .map(rereadDevice -> { - // The account and device still exist, and the device we originally read matches the current - // registrationId, so the lastResort key should have existed - log.error( - "Target {}, Account {}, DeviceId {}, RegistrationId {} was missing a last resort prekey", - targetIdentifier, - target.getIdentifier(IdentityType.ACI), - rereadDevice.getId(), - rereadDevice.getRegistrationId(targetIdentifier.identityType())); - return Mono.error(new IOException("Device missing last resort prekey")); - }) - .orElse(Mono.empty())); - } - return Mono.just(new PreKeyResponseItem( - device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey, pqPreKey)); - }) + final List responseItems = Flux.fromIterable(devices).flatMap(device -> Mono + .fromCompletionStage(keysManager.takeDevicePreKeys(device.getId(), targetIdentifier, userAgent)) + .flatMap(Mono::justOrEmpty) + .map(devicePreKeys -> new PreKeyResponseItem( + device.getId(), device.getRegistrationId(targetIdentifier.identityType()), + devicePreKeys.ecSignedPreKey(), + devicePreKeys.ecPreKey().orElse(null), + devicePreKeys.kemSignedPreKey()))) .collectList() .block(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java index 68d1680fd..7eb3536c9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java @@ -19,7 +19,7 @@ public class PreKeyResponseItem { private int registrationId; @JsonProperty - @Schema(description="the signed elliptic-curve prekey for the device, if one has been set") + @Schema(description="the signed elliptic-curve prekey for the device") private ECSignedPreKey signedPreKey; @JsonProperty @@ -28,7 +28,7 @@ public class PreKeyResponseItem { @JsonProperty @Schema(description="a signed post-quantum prekey for the device " + - "(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)") + "(a one-time prekey if any remain, otherwise the last-resort prekey)") private KEMSignedPreKey pqPreKey; public PreKeyResponseItem() {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java index 9e35b55c8..247e11761 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java @@ -56,7 +56,7 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), serviceIdentifier); yield lookUpAccount(serviceIdentifier, Status.NOT_FOUND) - .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)); + .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager)); } catch (final StatusException e) { yield Mono.error(e); } @@ -66,7 +66,7 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony lookUpAccount(serviceIdentifier, Status.UNAUTHENTICATED) .flatMap(targetAccount -> UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray()) - ? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager) + ? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager) : Mono.error(Status.UNAUTHENTICATED.asException())); default -> Mono.error(Status.INVALID_ARGUMENT.asException()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java index 4b5ec7dc0..423c4a99d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -11,10 +11,7 @@ import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcSignedPreKey; import org.signal.chat.common.KemSignedPreKey; import org.signal.chat.keys.GetPreKeysResponse; -import org.whispersystems.textsecuregcm.entities.ECPreKey; -import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; -import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; -import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; @@ -28,7 +25,7 @@ class KeysGrpcHelper { static final byte ALL_DEVICES = 0; static Mono getPreKeys(final Account targetAccount, - final IdentityType identityType, + final ServiceIdentifier targetServiceIdentifier, final byte targetDeviceId, final KeysManager keysManager) { @@ -36,42 +33,37 @@ class KeysGrpcHelper { ? Flux.fromIterable(targetAccount.getDevices()) : Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId))); + final String userAgent = RequestAttributesUtil.getUserAgent().orElse(null); return devices - .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) - .flatMap(device -> Flux.merge( - Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), - Mono.fromFuture(() -> keysManager.getEcSignedPreKey(targetAccount.getIdentifier(identityType), device.getId())), - Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId()))) + .flatMap(device -> Mono + .fromFuture(keysManager.takeDevicePreKeys(device.getId(), targetServiceIdentifier, userAgent)) .flatMap(Mono::justOrEmpty) - .reduce(GetPreKeysResponse.PreKeyBundle.newBuilder(), (builder, preKey) -> { - if (preKey instanceof ECPreKey ecPreKey) { - builder.setEcOneTimePreKey(EcPreKey.newBuilder() + .map(devicePreKeys -> { + final GetPreKeysResponse.PreKeyBundle.Builder builder = GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(devicePreKeys.ecSignedPreKey().keyId()) + .setPublicKey(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().serializedPublicKey())) + .setSignature(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().signature())) + .build()) + .setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(devicePreKeys.kemSignedPreKey().keyId()) + .setPublicKey(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().serializedPublicKey())) + .setSignature(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().signature())) + .build()); + devicePreKeys.ecPreKey().ifPresent(ecPreKey -> builder.setEcOneTimePreKey(EcPreKey.newBuilder() .setKeyId(ecPreKey.keyId()) .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) - .build()); - } else if (preKey instanceof ECSignedPreKey ecSignedPreKey) { - builder.setEcSignedPreKey(EcSignedPreKey.newBuilder() - .setKeyId(ecSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) - .build()); - } else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) { - builder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() - .setKeyId(kemSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) - .build()); - } else { - throw new AssertionError("Unexpected pre-key type: " + preKey.getClass()); - } - - return builder; - }) - // Cast device IDs to `int` to match data types in the response object’s protobuf definition - .map(builder -> Tuples.of((int) device.getId(), builder.build()))) + .build())); + // Cast device IDs to `int` to match data types in the response object’s protobuf definition + return Tuples.of((int) device.getId(), builder.build()); + })) + // If there were no devices with valid prekey bundles in the account, the account is gone + .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) .collectMap(Tuple2::getT1, Tuple2::getT2) .map(preKeyBundles -> GetPreKeysResponse.newBuilder() - .setIdentityKey(ByteString.copyFrom(targetAccount.getIdentityKey(identityType).serialize())) + .setIdentityKey(ByteString + .copyFrom(targetAccount.getIdentityKey(targetServiceIdentifier.identityType()) + .serialize())) .putAllPreKeys(preKeyBundles) .build()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java index ad09ba296..3799a5771 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java @@ -136,7 +136,7 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { .flatMap(Mono::justOrEmpty)) .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) .flatMap(targetAccount -> - KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier.identityType(), deviceId, keysManager)); + KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier, deviceId, keysManager)); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index 94f4a8bb4..37a415140 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -5,20 +5,31 @@ package org.whispersystems.textsecuregcm.storage; +import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Metrics; import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.util.Futures; +import org.whispersystems.textsecuregcm.util.Optionals; import reactor.core.publisher.Flux; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; +import javax.annotation.Nullable; public class KeysManager { + // KeysController for backwards compatibility + private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys"); private final SingleUseECPreKeyStore ecPreKeys; private final SingleUseKEMPreKeyStore pqPreKeys; @@ -115,11 +126,13 @@ public class KeysManager { } - public CompletableFuture> takeEC(final UUID identifier, final byte deviceId) { + @VisibleForTesting + CompletableFuture> takeEC(final UUID identifier, final byte deviceId) { return ecPreKeys.take(identifier, deviceId); } - public CompletableFuture> takePQ(final UUID identifier, final byte deviceId) { + @VisibleForTesting + CompletableFuture> takePQ(final UUID identifier, final byte deviceId) { final boolean enrolledInPagedKeys = experimentEnrollmentManager.isEnrolled(identifier, PAGED_KEYS_EXPERIMENT_NAME); return tagTakePQ(pagedPqPreKeys.take(identifier, deviceId), PQSource.PAGE, enrolledInPagedKeys) .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey @@ -209,4 +222,36 @@ public class KeysManager { public CompletableFuture pruneDeadPage(final UUID identifier, final byte deviceId, final UUID pageId) { return pagedPqPreKeys.deleteBundleFromS3(identifier, deviceId, pageId); } + + public record DevicePreKeys( + ECSignedPreKey ecSignedPreKey, + Optional ecPreKey, + KEMSignedPreKey kemSignedPreKey) {} + + public CompletableFuture> takeDevicePreKeys( + final byte deviceId, + final ServiceIdentifier serviceIdentifier, + final @Nullable String userAgent) { + final UUID uuid = serviceIdentifier.uuid(); + return Futures.zipWith( + this.takeEC(uuid, deviceId), + this.getEcSignedPreKey(uuid, deviceId), + this.takePQ(uuid, deviceId), + (maybeUnsignedEcPreKey, maybeSignedEcPreKey, maybePqPreKey) -> { + + Metrics.counter(GET_KEYS_COUNTER_NAME, Tags.of( + UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of("identityType", serviceIdentifier.identityType().name()), + Tag.of("oneTimeEcKeyAvailable", String.valueOf(maybeUnsignedEcPreKey.isPresent())), + Tag.of("signedEcKeyAvailable", String.valueOf(maybeSignedEcPreKey.isPresent())), + Tag.of("pqKeyAvailable", String.valueOf(maybePqPreKey.isPresent())))) + .increment(); + + // The pq prekey and signed EC prekey should never be null for an existing account. This should only happen + // if the account or device has been removed and the read was split, so we can return empty in those cases. + return Optionals.zipWith(maybeSignedEcPreKey, maybePqPreKey, (signedEcPreKey, pqPreKey) -> + new DevicePreKeys(signedEcPreKey, maybeUnsignedEcPreKey, pqPreKey)); + }) + .toCompletableFuture(); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/Futures.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/Futures.java new file mode 100644 index 000000000..a7471aca5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/Futures.java @@ -0,0 +1,21 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import java.util.concurrent.CompletionStage; +import org.apache.commons.lang3.function.TriFunction; + +public class Futures { + + public static CompletionStage zipWith( + CompletionStage futureT, + CompletionStage futureU, + CompletionStage futureV, + TriFunction fun) { + + return futureT.thenCompose(t -> futureU.thenCombine(futureV, (u, v) -> fun.apply(t, u, v))); + } +} diff --git a/service/src/main/proto/org/signal/chat/keys.proto b/service/src/main/proto/org/signal/chat/keys.proto index 4db5b11aa..368bd432e 100644 --- a/service/src/main/proto/org/signal/chat/keys.proto +++ b/service/src/main/proto/org/signal/chat/keys.proto @@ -195,8 +195,7 @@ message GetPreKeysResponse { /** * A one-time KEM pre-key (or a last-resort KEM pre-key) for the targeted - * account/device/identity. May not be set if the targeted device has not - * yet uploaded any KEM pre-keys. + * account/device/identity. */ common.KemSignedPreKey kem_one_time_pre_key = 3; } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index a95937469..e2ad61184 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyString; @@ -35,6 +36,7 @@ import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -63,6 +65,7 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyResponse; +import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.SetKeysRequest; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; @@ -93,6 +96,7 @@ class KeysControllerTest { private static final UUID EXISTS_UUID = UUID.randomUUID(); private static final UUID EXISTS_PNI = UUID.randomUUID(); private static final AciServiceIdentifier EXISTS_ACI = new AciServiceIdentifier(EXISTS_UUID); + private static final PniServiceIdentifier EXISTS_PNI_SERVICE_ID = new PniServiceIdentifier(EXISTS_PNI); private static final UUID OTHER_UUID = UUID.randomUUID(); private static final AciServiceIdentifier OTHER_ACI = new AciServiceIdentifier(OTHER_UUID); @@ -101,16 +105,9 @@ class KeysControllerTest { private static final AciServiceIdentifier NOT_EXISTS_ACI = new AciServiceIdentifier(NOT_EXISTS_UUID); private static final byte SAMPLE_DEVICE_ID = 1; - private static final byte SAMPLE_DEVICE_ID2 = 2; - private static final byte SAMPLE_DEVICE_ID3 = 3; - private static final byte SAMPLE_DEVICE_ID4 = 4; private static final int SAMPLE_REGISTRATION_ID = 999; - private static final int SAMPLE_REGISTRATION_ID2 = 1002; - private static final int SAMPLE_REGISTRATION_ID4 = 1555; - private static final int SAMPLE_PNI_REGISTRATION_ID = 1717; - private static final int SAMPLE_PNI_REGISTRATION_ID2 = 1718; private final ECKeyPair IDENTITY_KEY_PAIR = ECKeyPair.generate(); private final IdentityKey IDENTITY_KEY = new IdentityKey(IDENTITY_KEY_PAIR.getPublicKey()); @@ -119,34 +116,21 @@ class KeysControllerTest { private final IdentityKey PNI_IDENTITY_KEY = new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()); private final ECPreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234); - private final ECPreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667); - private final ECPreKey SAMPLE_KEY3 = KeysHelper.ecPreKey(334); - private final ECPreKey SAMPLE_KEY4 = KeysHelper.ecPreKey(336); - private final ECPreKey SAMPLE_KEY_PNI = KeysHelper.ecPreKey(7777); private final KEMSignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, ECKeyPair.generate()); - private final KEMSignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, ECKeyPair.generate()); - private final KEMSignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, ECKeyPair.generate()); - private final KEMSignedPreKey SAMPLE_PQ_KEY4 = KeysHelper.signedKEMPreKey(7676, ECKeyPair.generate()); - private final KEMSignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, ECKeyPair.generate()); private final ECSignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedECPreKey(1111, IDENTITY_KEY_PAIR); - private final ECSignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedECPreKey(2222, IDENTITY_KEY_PAIR); - private final ECSignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedECPreKey(3333, IDENTITY_KEY_PAIR); - private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(4444, PNI_IDENTITY_KEY_PAIR); - private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR); - private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedECPreKey(6666, PNI_IDENTITY_KEY_PAIR); - private final ECSignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR); - private final ECSignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR); + private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR); - private final static KeysManager KEYS = mock(KeysManager.class ); - private final static AccountsManager accounts = mock(AccountsManager.class ); - private final static Account existsAccount = mock(Account.class ); - private static final RateLimiters rateLimiters = mock(RateLimiters.class); - private static final RateLimiter rateLimiter = mock(RateLimiter.class ); + private final static KeysManager KEYS = mock(KeysManager.class); + private final static AccountsManager accounts = mock(AccountsManager.class); + private final static Account existsAccount = mock(Account.class); + + private static final RateLimiters rateLimiters = mock(RateLimiters.class); + private static final RateLimiter rateLimiter = mock(RateLimiter.class); private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate(); @@ -163,9 +147,6 @@ class KeysControllerTest { .addResource(new RateLimitExceededExceptionMapper()) .build(); - private Device sampleDevice; - private Device sampleDevice2; - private record WeaklyTypedPreKey(long keyId, @JsonSerialize(using = ByteArrayAdapter.Serializing.class) @@ -198,34 +179,33 @@ class KeysControllerTest { byte[] identityKey) { } + private Device createSampleDevice(byte deviceId, int registrationId, int pniRegistrationId) { + final Device sampleDevice = mock(Device.class); + when(sampleDevice.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId); + when(sampleDevice.getRegistrationId(IdentityType.PNI)).thenReturn(pniRegistrationId); + when(sampleDevice.getId()).thenReturn(deviceId); + + return sampleDevice; + } + @BeforeEach void setup() { clock.unpin(); - sampleDevice = mock(Device.class); - sampleDevice2 = mock(Device.class); - final Device sampleDevice3 = mock(Device.class); - final Device sampleDevice4 = mock(Device.class); - - final List allDevices = List.of(sampleDevice, sampleDevice2, sampleDevice3, sampleDevice4); - - final byte sampleDeviceId = 1; - final byte sampleDevice2Id = 2; - final byte sampleDevice3Id = 3; - final byte sampleDevice4Id = 4; - AccountsHelper.setupMockUpdate(accounts); - when(sampleDevice.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID); - when(sampleDevice2.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID2); - when(sampleDevice3.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID2); - when(sampleDevice4.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID4); - when(sampleDevice.getRegistrationId(IdentityType.PNI)).thenReturn(SAMPLE_PNI_REGISTRATION_ID); - when(sampleDevice2.getRegistrationId(IdentityType.PNI)).thenReturn(SAMPLE_PNI_REGISTRATION_ID2); - when(sampleDevice.getId()).thenReturn(sampleDeviceId); - when(sampleDevice2.getId()).thenReturn(sampleDevice2Id); - when(sampleDevice3.getId()).thenReturn(sampleDevice3Id); - when(sampleDevice4.getId()).thenReturn(sampleDevice4Id); + final Device sampleDevice = + createSampleDevice(SAMPLE_DEVICE_ID, SAMPLE_REGISTRATION_ID, SAMPLE_PNI_REGISTRATION_ID); + + final KeysManager.DevicePreKeys aciKeys = + new KeysManager.DevicePreKeys(SAMPLE_SIGNED_KEY, Optional.of(SAMPLE_KEY), SAMPLE_PQ_KEY); + final KeysManager.DevicePreKeys pniKeys = + new KeysManager.DevicePreKeys(SAMPLE_SIGNED_PNI_KEY, Optional.of(SAMPLE_KEY_PNI), SAMPLE_PQ_KEY_PNI); + + when(KEYS.takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_ACI), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(aciKeys))); + when(KEYS.takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_PNI_SERVICE_ID), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(pniKeys))); when(existsAccount.getUuid()).thenReturn(EXISTS_UUID); when(existsAccount.isIdentifiedBy(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(true); @@ -233,12 +213,8 @@ class KeysControllerTest { when(existsAccount.isIdentifiedBy(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(true); when(existsAccount.getIdentifier(IdentityType.ACI)).thenReturn(EXISTS_UUID); when(existsAccount.getIdentifier(IdentityType.PNI)).thenReturn(EXISTS_PNI); - when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice)); - when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2)); - when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3)); - when(existsAccount.getDevice(sampleDevice4Id)).thenReturn(Optional.of(sampleDevice4)); - when(existsAccount.getDevice((byte) 22)).thenReturn(Optional.empty()); - when(existsAccount.getDevices()).thenReturn(allDevices); + when(existsAccount.getDevice(SAMPLE_DEVICE_ID)).thenReturn(Optional.of(sampleDevice)); + when(existsAccount.getDevices()).thenReturn(List.of(sampleDevice)); when(existsAccount.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY); when(existsAccount.getIdentityKey(IdentityType.PNI)).thenReturn(PNI_IDENTITY_KEY); when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER); @@ -264,51 +240,12 @@ class KeysControllerTest { when(KEYS.storeEcOneTimePreKeys(any(), anyByte(), any())) .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); - when(KEYS.storeKemOneTimePreKeys(any(), anyByte(), any())) .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); - when(KEYS.storePqLastResort(any(), anyByte(), any())) .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); - - when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(KEYS.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); - - when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDeviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY))); - - when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice2Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY2))); - - when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice3Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY3))); - - when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDeviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY))); - - when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice2Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY2))); - - when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice3Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY3))); - - when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); - when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); - when(KEYS.takeEC(EXISTS_PNI, sampleDeviceId)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI))); - when(KEYS.takePQ(EXISTS_PNI, sampleDeviceId)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI))); - - when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); - when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); - - when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId())) - .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_SIGNED_KEY))); - - when(KEYS.getEcSignedPreKey(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId())) - .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_PNI_SIGNED_KEY))); + when(KEYS.storeEcSignedPreKeys(any(), anyByte(), any())) + .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); } @AfterEach @@ -326,6 +263,11 @@ class KeysControllerTest { @Test void validKeyStatusTest() { + when(KEYS.getEcCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID)).thenReturn(CompletableFuture.completedFuture(5)); + when(KEYS.getPqCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID)).thenReturn(CompletableFuture.completedFuture(5)); + when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(KeysHelper.signedECPreKey(123, IDENTITY_KEY_PAIR)))); + PreKeyCount result = resources.getJerseyTest() .target("/v2/keys") .request() @@ -355,9 +297,7 @@ class KeysControllerTest { assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_ACI), any()); verifyNoMoreInteractions(KEYS); } @@ -377,9 +317,7 @@ class KeysControllerTest { assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_ACI), any()); verifyNoMoreInteractions(KEYS); } @@ -398,9 +336,7 @@ class KeysControllerTest { assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_PNI_SERVICE_ID), any()); verifyNoMoreInteractions(KEYS); } @@ -420,9 +356,7 @@ class KeysControllerTest { assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_PNI_SERVICE_ID), any()); verifyNoMoreInteractions(KEYS); } @@ -456,64 +390,40 @@ class KeysControllerTest { assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_ACI), any()); verifyNoMoreInteractions(KEYS); } - private enum RereadBehavior { - ACCOUNT_MISSING, - DEVICE_MISSING, - REG_ID_CHANGED, - PRESENT - } - - @ParameterizedTest - @EnumSource(RereadBehavior.class) - void testGetKeysMissingLastResort(RereadBehavior rereadBehavior) { - when(KEYS.takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); - when(KEYS.takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID2)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2))); - - when(KEYS.takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); - when(KEYS.takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID2)) - // Missing PQ key + @Test + void testNoKeysForDevice() { + final List devices = List.of( + createSampleDevice((byte) 1, 2, 3), + createSampleDevice((byte) 4, 5, 6)); + // device 1 is missing required prekeys, device 4 is missing an optional EC prekey + when(KEYS.takeDevicePreKeys(eq((byte) 4), eq(EXISTS_PNI_SERVICE_ID), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of( + new KeysManager.DevicePreKeys(SAMPLE_SIGNED_PNI_KEY, Optional.empty(), SAMPLE_PQ_KEY_PNI)))); + when(KEYS.takeDevicePreKeys(eq((byte) 1), eq(EXISTS_PNI_SERVICE_ID), any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - switch (rereadBehavior) { - case ACCOUNT_MISSING -> when(accounts.getByServiceIdentifierAsync(new PniServiceIdentifier(EXISTS_PNI))) - .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - case DEVICE_MISSING -> when(existsAccount.getDevice(SAMPLE_DEVICE_ID2)) - .thenReturn(Optional.empty()); - case REG_ID_CHANGED -> when(sampleDevice2.getRegistrationId(IdentityType.PNI)) - .thenReturn(SAMPLE_PNI_REGISTRATION_ID2) - .thenReturn(SAMPLE_PNI_REGISTRATION_ID2 + 1); - case PRESENT -> { - } - } + when(existsAccount.getDevice((byte) 1)).thenReturn(Optional.of(devices.get(0))); + when(existsAccount.getDevice((byte) 4)).thenReturn(Optional.of(devices.get(1))); + when(existsAccount.getDevices()).thenReturn(devices); - when(existsAccount.getDevices()).thenReturn(List.of(sampleDevice, sampleDevice2)); - - Response response = resources.getJerseyTest() + PreKeyResponse results = resources.getJerseyTest() .target(String.format("/v2/keys/PNI:%s/*", EXISTS_PNI)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(); - if (rereadBehavior == RereadBehavior.PRESENT) { - // The device was missing a last resort prekey which should be impossible - assertThat(response.getStatus()).isEqualTo(500); - } else { - // In the other cases, the device plausibly disappeared so we can just leave that device out - final PreKeyResponse result = response.readEntity(PreKeyResponse.class); - assertThat(result.getDevicesCount()).isEqualTo(1); - assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI)); - assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY); - } - } + .get(PreKeyResponse.class); + // Should drop device 1 and keep device 4 + assertThat(results.getDevicesCount()).isEqualTo(1); + final PreKeyResponseItem result = results.getDevice((byte) 4); + assertEquals(6, result.getRegistrationId()); + assertNull(result.getPreKey()); + assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getSignedPreKey()); + assertEquals(SAMPLE_PQ_KEY_PNI, result.getPqPreKey()); + } @ParameterizedTest @MethodSource @@ -545,9 +455,7 @@ class KeysControllerTest { assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takeDevicePreKeys(eq(SAMPLE_DEVICE_ID), eq(EXISTS_ACI), any()); } verifyNoMoreInteractions(KEYS); @@ -617,162 +525,54 @@ class KeysControllerTest { } - @Test - void validMultiRequestTestV2() { - when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); - when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2))); - when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); - when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); + @ParameterizedTest + @EnumSource + void validMultiRequestTestV2(IdentityType identityType) { + final ServiceIdentifier serviceIdentifier = switch (identityType) { + case ACI -> new AciServiceIdentifier(EXISTS_UUID); + case PNI -> new PniServiceIdentifier(EXISTS_PNI); + }; - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2))); - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3))); - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY4))); + final List devices = new ArrayList<>(); + final List devicePreKeys = new ArrayList<>(); + for (int i = 0; i < 4; i++) { + devices.add(createSampleDevice((byte) i, i + 100, i + 200)); + + final ECSignedPreKey signedEcPreKey = KeysHelper.signedECPreKey(i + 300, IDENTITY_KEY_PAIR); + final ECPreKey ecPreKey = KeysHelper.ecPreKey(i + 400); + final KEMSignedPreKey kemSignedPreKey = KeysHelper.signedKEMPreKey(i + 500, ECKeyPair.generate()); + devicePreKeys.add(new KeysManager.DevicePreKeys(signedEcPreKey, Optional.of(ecPreKey), kemSignedPreKey)); + + when(existsAccount.getDevice((byte) i)).thenReturn(Optional.of(devices.getLast())); + when(KEYS.takeDevicePreKeys(eq((byte) i), eq(serviceIdentifier), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(devicePreKeys.getLast()))); + } + when(existsAccount.getDevices()).thenReturn(devices); PreKeyResponse results = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) + .target(String.format("/v2/keys/%s/*", serviceIdentifier.toServiceIdentifierString())) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(PreKeyResponse.class); assertThat(results.getDevicesCount()).isEqualTo(4); - assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); + assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(identityType)); - ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey(); - ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey(); - long registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId(); - byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId(); + for (int i = 0; i < 4; i++) { + final PreKeyResponseItem result = results.getDevice((byte) i); + final KeysManager.DevicePreKeys expectedPreKeys = devicePreKeys.get(i); + final Device expectedDevice = devices.get(i); - assertEquals(SAMPLE_KEY, preKey); - assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_KEY, signedPreKey); - assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID); + assertEquals(expectedDevice.getRegistrationId(identityType), result.getRegistrationId()); + assertEquals(expectedPreKeys.ecPreKey().orElseThrow(), result.getPreKey()); + assertEquals(expectedPreKeys.ecSignedPreKey(), result.getSignedPreKey()); + assertEquals(expectedPreKeys.kemSignedPreKey(), result.getPqPreKey()); - signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey(); - preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey(); - registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId(); - deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId(); - - assertEquals(SAMPLE_KEY2, preKey); - assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); - assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey); - assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2); - - signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey(); - preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey(); - registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId(); - deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId(); - - assertEquals(SAMPLE_KEY4, preKey); - assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); - assertThat(signedPreKey).isNull(); - assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4); - - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID3); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4); + verify(KEYS).takeDevicePreKeys(eq((byte) i), eq(serviceIdentifier), any()); + } verifyNoMoreInteractions(KEYS); } - @Test - void validMultiRequestPqTestV2() { - when(KEYS.takeEC(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(KEYS.takePQ(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - - when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); - when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); - when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2))); - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3))); - when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( - CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY4))); - - PreKeyResponse results = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) - .queryParam("pq", "true") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(PreKeyResponse.class); - - assertThat(results.getDevicesCount()).isEqualTo(4); - assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); - - ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey(); - ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey(); - KEMSignedPreKey pqPreKey = results.getDevice(SAMPLE_DEVICE_ID).getPqPreKey(); - int registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId(); - byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId(); - - assertEquals(SAMPLE_KEY, preKey); - assertEquals(SAMPLE_PQ_KEY, pqPreKey); - assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_KEY, signedPreKey); - assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID); - - signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey(); - preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey(); - pqPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getPqPreKey(); - registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId(); - deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId(); - - assertThat(preKey).isNull(); - assertEquals(SAMPLE_PQ_KEY2, pqPreKey); - assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); - assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey); - assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2); - - signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey(); - preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey(); - pqPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getPqPreKey(); - registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId(); - deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId(); - - assertEquals(SAMPLE_KEY4, preKey); - assertThat(pqPreKey).isEqualTo(SAMPLE_PQ_KEY4); - assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); - assertThat(signedPreKey).isNull(); - assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4); - - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3); - verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4); - verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID3); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4); - verifyNoMoreInteractions(KEYS); - } - - @Test void invalidRequestTestV2() { Response response = resources.getJerseyTest() @@ -786,6 +586,7 @@ class KeysControllerTest { @Test void anotherInvalidRequestTestV2() { + when(existsAccount.getDevice((byte) 22)).thenReturn(Optional.empty()); Response response = resources.getJerseyTest() .target(String.format("/v2/keys/%s/22", EXISTS_UUID)) .request() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java index 2b1eae68b..23756e902 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.grpc; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @@ -100,15 +101,11 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest new PniServiceIdentifier(identifier); + case ACI -> new AciServiceIdentifier(identifier); + }; when(targetAccount.getUuid()).thenReturn(UUID.randomUUID()); when(targetAccount.getIdentifier(identityType)).thenReturn(identifier); when(targetAccount.getIdentityKey(identityType)).thenReturn(identityKey); - when(accountsManager.getByServiceIdentifierAsync(argThat(serviceIdentifier -> serviceIdentifier.uuid().equals(identifier)))) + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); - final Map ecOneTimePreKeys = new HashMap<>(); - final Map kemPreKeys = new HashMap<>(); - final Map ecSignedPreKeys = new HashMap<>(); + final Map devicePreKeysMap = new HashMap<>(); final Map devices = new HashMap<>(); + final Map expectedPreKeyBundles = new HashMap<>(); final byte deviceId1 = 1; final byte deviceId2 = 2; for (final byte deviceId : List.of(deviceId1, deviceId2)) { - ecOneTimePreKeys.put(deviceId, new ECPreKey(1, ECKeyPair.generate().getPublicKey())); - kemPreKeys.put(deviceId, KeysHelper.signedKEMPreKey(2, identityKeyPair)); - ecSignedPreKeys.put(deviceId, KeysHelper.signedECPreKey(3, identityKeyPair)); + + final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(3, identityKeyPair); + final Optional maybeEcPreKey = Optional + .of(new ECPreKey(1, ECKeyPair.generate().getPublicKey())) + .filter(_ -> deviceId == deviceId1); + final KEMSignedPreKey kemSignedPreKey = KeysHelper.signedKEMPreKey(2, identityKeyPair); + + devicePreKeysMap.put(deviceId, new KeysManager.DevicePreKeys(ecSignedPreKey, maybeEcPreKey, kemSignedPreKey)); + + final GetPreKeysResponse.PreKeyBundle.Builder builder = GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) + .build()) + .setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) + .build()); + maybeEcPreKey.ifPresent(ecPreKey -> builder + .setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) + .build())); + expectedPreKeyBundles.put(deviceId, builder.build()); final Device device = mock(Device.class); when(device.getId()).thenReturn(deviceId); @@ -494,14 +521,9 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest(devices.values())); - ecOneTimePreKeys.forEach((deviceId, preKey) -> when(keysManager.takeEC(identifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); - - ecSignedPreKeys.forEach((deviceId, preKey) -> when(keysManager.getEcSignedPreKey(identifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); - - kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); + devicePreKeysMap.forEach((deviceId, preKeys) -> when(keysManager.takeDevicePreKeys(eq(deviceId), + eq(serviceIdentifier), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(preKeys)))); { final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() @@ -514,30 +536,12 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest keys = + keysManager.takeDevicePreKeys(DEVICE_ID, ACI_SERVICE_IDENTIFIER, null).join(); + + assertEquals(keys.isPresent(), switch (missingKeyType) { + // We should successfully get keys if every key is present, or if only EC one-time keys are missing + case EC, NONE -> true; + // If the signed EC key or the last-resort PQ key is missing, we shouldn't get keys back + case SIGNED_EC, PQ -> false; + }); + + final boolean hasEcPreKey = keys.flatMap(KeysManager.DevicePreKeys::ecPreKey).isPresent(); + assertEquals(hasEcPreKey, missingKeyType == MissingKeyType.NONE); + } + private static ECPreKey generateTestPreKey(final long keyId) { return new ECPreKey(keyId, ECKeyPair.generate().getPublicKey()); }