diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java index 29783ed5e..50d6b5cb1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java @@ -10,25 +10,25 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.time.Clock; import java.util.Arrays; +import java.util.concurrent.Flow; import org.signal.chat.errors.FailedUnidentifiedAuthorization; import org.signal.chat.errors.NotFound; import org.signal.chat.keys.CheckIdentityKeyRequest; import org.signal.chat.keys.CheckIdentityKeyResponse; import org.signal.chat.keys.GetPreKeysAnonymousRequest; import org.signal.chat.keys.GetPreKeysAnonymousResponse; -import org.signal.chat.keys.ReactorKeysAnonymousGrpc; +import org.signal.chat.keys.SimpleKeysAnonymousGrpc; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.KeysManager; -import reactor.core.publisher.Flux; +import reactor.adapter.JdkFlowAdapter; import reactor.core.publisher.Mono; import reactor.util.function.Tuples; -public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnonymousImplBase { +public class KeysAnonymousGrpcService extends SimpleKeysAnonymousGrpc.KeysAnonymousImplBase { private final AccountsManager accountsManager; private final KeysManager keysManager; @@ -42,7 +42,7 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony } @Override - public Mono getPreKeys(final GetPreKeysAnonymousRequest request) { + public GetPreKeysAnonymousResponse getPreKeys(final GetPreKeysAnonymousRequest request) { final ServiceIdentifier serviceIdentifier = ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getRequest().getTargetIdentifier()); @@ -53,34 +53,34 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony return switch (request.getAuthorizationCase()) { case GROUP_SEND_TOKEN -> { if (!groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), serviceIdentifier)) { - yield Mono.fromSupplier(() -> GetPreKeysAnonymousResponse.newBuilder() + yield GetPreKeysAnonymousResponse.newBuilder() + .setFailedUnidentifiedAuthorization(FailedUnidentifiedAuthorization.getDefaultInstance()) + .build(); + } + + yield accountsManager.getByServiceIdentifier(serviceIdentifier) + .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager)) + .map(accountPreKeyBundles -> GetPreKeysAnonymousResponse.newBuilder().setPreKeys(accountPreKeyBundles).build()) + .orElseGet(() -> GetPreKeysAnonymousResponse.newBuilder() + .setTargetNotFound(NotFound.getDefaultInstance()) + .build()); + } + + case UNIDENTIFIED_ACCESS_KEY -> accountsManager.getByServiceIdentifier(serviceIdentifier) + .filter(targetAccount -> UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())) + .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager)) + .map(accountPreKeyBundles -> GetPreKeysAnonymousResponse.newBuilder().setPreKeys(accountPreKeyBundles).build()) + .orElseGet(() -> GetPreKeysAnonymousResponse.newBuilder() .setFailedUnidentifiedAuthorization(FailedUnidentifiedAuthorization.getDefaultInstance()) .build()); - } - yield lookUpAccount(serviceIdentifier) - .flatMap(targetAccount -> KeysGrpcHelper - .getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager)) - .map(preKeys -> GetPreKeysAnonymousResponse.newBuilder().setPreKeys(preKeys).build()) - .switchIfEmpty(Mono.fromSupplier(() -> GetPreKeysAnonymousResponse.newBuilder() - .setTargetNotFound(NotFound.getDefaultInstance()) - .build())); - } - case UNIDENTIFIED_ACCESS_KEY -> lookUpAccount(serviceIdentifier) - .filter(targetAccount -> - UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())) - .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier, deviceId, keysManager)) - .map(preKeys -> GetPreKeysAnonymousResponse.newBuilder().setPreKeys(preKeys).build()) - .switchIfEmpty(Mono.fromSupplier(() -> GetPreKeysAnonymousResponse.newBuilder() - .setFailedUnidentifiedAuthorization(FailedUnidentifiedAuthorization.getDefaultInstance()) - .build())); - default -> Mono.error(GrpcExceptions.fieldViolation("authorization", "invalid authorization type")); + default -> throw GrpcExceptions.fieldViolation("authorization", "invalid authorization type"); }; } @Override - public Flux checkIdentityKeys(final Flux requests) { - return requests + public Flow.Publisher checkIdentityKeys(final Flow.Publisher requests) { + return JdkFlowAdapter.publisherToFlowPublisher(JdkFlowAdapter.flowPublisherToFlux(requests) .map(request -> Tuples.of(ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getTargetIdentifier()), request.getFingerprint().toByteArray())) .flatMap(serviceIdentifierAndFingerprint -> Mono.fromFuture( @@ -89,17 +89,11 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony .filter(account -> !fingerprintMatches(account.getIdentityKey(serviceIdentifierAndFingerprint.getT1() .identityType()), serviceIdentifierAndFingerprint.getT2())) .map(account -> CheckIdentityKeyResponse.newBuilder() - .setTargetIdentifier( - ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifierAndFingerprint.getT1())) - .setIdentityKey(ByteString.copyFrom(account.getIdentityKey(serviceIdentifierAndFingerprint.getT1() - .identityType()).serialize())) - .build()) - ); - } - - private Mono lookUpAccount(final ServiceIdentifier serviceIdentifier) { - return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) - .flatMap(Mono::justOrEmpty); + .setTargetIdentifier( + ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifierAndFingerprint.getT1())) + .setIdentityKey(ByteString.copyFrom(account.getIdentityKey(serviceIdentifierAndFingerprint.getT1() + .identityType()).serialize())) + .build()))); } private static boolean fingerprintMatches(final IdentityKey identityKey, final byte[] fingerprint) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java index 99717cf43..4fed1e190 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -6,84 +6,85 @@ package org.whispersystems.textsecuregcm.grpc; import com.google.protobuf.ByteString; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcSignedPreKey; import org.signal.chat.common.KemSignedPreKey; import org.signal.chat.keys.AccountPreKeyBundles; import org.signal.chat.keys.DevicePreKeyBundle; -import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.util.function.Tuple2; -import reactor.util.function.Tuples; -import java.util.Optional; class KeysGrpcHelper { static final byte ALL_DEVICES = 0; - /** - * Fetch {@link AccountPreKeyBundles} from the targetAccount - * - * @param targetAccount The targetAccount to fetch pre-key bundles from - * @param targetServiceIdentifier The serviceIdentifier used to lookup the targetAccount - * @param targetDeviceId The deviceId to retrieve pre-key bundles for, or ALL_DEVICES if all devices should be retrieved - * @param keysManager The {@link KeysManager} to lookup pre-keys from - * @return The requested bundles, or an empty Mono if the keys for the targetAccount do not exist - */ - static Mono getPreKeys(final Account targetAccount, + /// Fetch {@link AccountPreKeyBundles} from the targetAccount + /// + /// @param targetAccount the account to fetch pre-key bundles from + /// @param targetServiceIdentifier the service identifier for the target Account + /// @param targetDeviceId the device ID to retrieve pre-key bundles for, or [#ALL_DEVICES] if all devices should be + /// retrieved + /// @param keysManager The {@link KeysManager} to lookup pre-keys from + /// + /// @return the requested bundles, or empty if the keys for the `targetAccount` do not exist + static Optional getPreKeys(final Account targetAccount, final ServiceIdentifier targetServiceIdentifier, final byte targetDeviceId, final KeysManager keysManager) { - final Flux devices = targetDeviceId == ALL_DEVICES - ? Flux.fromIterable(targetAccount.getDevices()) - : Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId))); + final Stream devices = targetDeviceId == ALL_DEVICES + ? targetAccount.getDevices().stream() + : targetAccount.getDevice(targetDeviceId).stream(); final String userAgent = RequestAttributesUtil.getUserAgent().orElse(null); - return devices - .flatMap(device -> { - final int registrationId = device.getRegistrationId(targetServiceIdentifier.identityType()); - return Mono - .fromFuture(keysManager.takeDevicePreKeys(device.getId(), targetServiceIdentifier, userAgent)) - .flatMap(Mono::justOrEmpty) - .map(devicePreKeys -> { - final DevicePreKeyBundle.Builder builder = DevicePreKeyBundle.newBuilder() - .setEcSignedPreKey(EcSignedPreKey.newBuilder() - .setKeyId(devicePreKeys.ecSignedPreKey().keyId()) - .setPublicKey(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().serializedPublicKey())) - .setSignature(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().signature())) - .build()) - .setKemOneTimePreKey(KemSignedPreKey.newBuilder() - .setKeyId(devicePreKeys.kemSignedPreKey().keyId()) - .setPublicKey(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().serializedPublicKey())) - .setSignature(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().signature())) - .build()) - .setRegistrationId(registrationId); - devicePreKeys.ecPreKey().ifPresent(ecPreKey -> builder.setEcOneTimePreKey(EcPreKey.newBuilder() - .setKeyId(ecPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) - .build())); - // Cast device IDs to `int` to match data types in the response object’s protobuf definition - return Tuples.of((int) device.getId(), builder.build()); - }); - }) - .collectMap(Tuple2::getT1, Tuple2::getT2) - .flatMap(preKeyBundles -> { - if (preKeyBundles.isEmpty()) { - // If there were no devices with valid prekey bundles in the account, the account is gone - return Mono.empty(); - } - final IdentityKey targetIdentityKey = targetAccount.getIdentityKey(targetServiceIdentifier.identityType()); - return Mono.just(AccountPreKeyBundles.newBuilder() - .setIdentityKey(ByteString.copyFrom(targetIdentityKey.serialize())) - .putAllDevicePreKeys(preKeyBundles) - .build()); - }); + final Map>> takeKeyFuturesByDeviceId = + devices.collect(Collectors.toMap( + Device::getId, + device -> keysManager.takeDevicePreKeys(device.getId(), targetServiceIdentifier, userAgent))); + + CompletableFuture.allOf(takeKeyFuturesByDeviceId.values().toArray(CompletableFuture[]::new)).join(); + + final Map preKeysByDeviceId = takeKeyFuturesByDeviceId.entrySet().stream() + .filter(entry -> entry.getValue().resultNow().isPresent()) + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().resultNow().orElseThrow())); + + if (preKeysByDeviceId.isEmpty()) { + // If there were no devices with valid prekey bundles in the account, the account is gone + return Optional.empty(); + } + + final AccountPreKeyBundles.Builder preKeyBundlesBuilder = AccountPreKeyBundles.newBuilder() + .setIdentityKey(ByteString.copyFrom(targetAccount.getIdentityKey(targetServiceIdentifier.identityType()).serialize())); + + preKeysByDeviceId.forEach((deviceId, devicePreKeys) -> { + final Device device = targetAccount.getDevice(deviceId).orElseThrow(); + + final DevicePreKeyBundle.Builder builder = DevicePreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(devicePreKeys.ecSignedPreKey().keyId()) + .setPublicKey(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().serializedPublicKey())) + .setSignature(ByteString.copyFrom(devicePreKeys.ecSignedPreKey().signature()))) + .setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(devicePreKeys.kemSignedPreKey().keyId()) + .setPublicKey(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().serializedPublicKey())) + .setSignature(ByteString.copyFrom(devicePreKeys.kemSignedPreKey().signature()))) + .setRegistrationId(device.getRegistrationId(targetServiceIdentifier.identityType())); + + devicePreKeys.ecPreKey().ifPresent(ecPreKey -> builder.setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())))); + + preKeyBundlesBuilder.putDevicePreKeys(deviceId, builder.build()); + }); + + return Optional.of(preKeyBundlesBuilder.build()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java index 32530ab9c..b33c3046b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java @@ -18,18 +18,19 @@ import org.signal.chat.keys.GetPreKeyCountRequest; import org.signal.chat.keys.GetPreKeyCountResponse; import org.signal.chat.keys.GetPreKeysRequest; import org.signal.chat.keys.GetPreKeysResponse; -import org.signal.chat.keys.ReactorKeysGrpc; import org.signal.chat.keys.SetEcSignedPreKeyRequest; import org.signal.chat.keys.SetKemLastResortPreKeyRequest; import org.signal.chat.keys.SetOneTimeEcPreKeysRequest; import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest; import org.signal.chat.keys.SetPreKeyResponse; +import org.signal.chat.keys.SimpleKeysGrpc; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.kem.KEMPublicKey; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; @@ -39,11 +40,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.KeysManager; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.util.function.Tuples; -public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { +public class KeysGrpcService extends SimpleKeysGrpc.KeysImplBase { private final AccountsManager accountsManager; private final KeysManager keysManager; @@ -55,11 +53,6 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { private static final StatusRuntimeException INVALID_SIGNATURE_EXCEPTION = GrpcExceptions.fieldViolation("pre_keys", "pre-key signature did not match account identity key"); - private enum PreKeyType { - EC, - KEM - } - public KeysGrpcService(final AccountsManager accountsManager, final KeysManager keysManager, final RateLimiters rateLimiters) { @@ -70,48 +63,34 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { } @Override - public Mono getPreKeyCount(final GetPreKeyCountRequest request) { - return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) - .flatMap(authenticatedDevice -> getAuthenticatedAccount(authenticatedDevice.accountIdentifier()) - .zipWith(Mono.just(authenticatedDevice.deviceId()))) - .flatMapMany(accountAndDeviceId -> Flux.just( - Tuples.of(IdentityType.ACI, accountAndDeviceId.getT1().getUuid(), accountAndDeviceId.getT2()), - Tuples.of(IdentityType.PNI, accountAndDeviceId.getT1().getPhoneNumberIdentifier(), accountAndDeviceId.getT2()) - )) - .flatMap(identityTypeUuidAndDeviceId -> Flux.merge( - Mono.fromFuture(() -> keysManager.getEcCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3())) - .map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.EC, ecKeyCount)), + public GetPreKeyCountResponse getPreKeyCount(final GetPreKeyCountRequest request) { + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + final Account account = getAuthenticatedAccount(authenticatedDevice.accountIdentifier()); - Mono.fromFuture(() -> keysManager.getPqCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3())) - .map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.KEM, ecKeyCount)) - )) - .reduce(GetPreKeyCountResponse.newBuilder(), (builder, tuple) -> { - final IdentityType identityType = tuple.getT1(); - final PreKeyType preKeyType = tuple.getT2(); - final int count = tuple.getT3(); + final CompletableFuture aciEcKeyCountFuture = + keysManager.getEcCount(account.getIdentifier(IdentityType.ACI), authenticatedDevice.deviceId()); - switch (identityType) { - case ACI -> { - switch (preKeyType) { - case EC -> builder.setAciEcPreKeyCount(count); - case KEM -> builder.setAciKemPreKeyCount(count); - } - } - case PNI -> { - switch (preKeyType) { - case EC -> builder.setPniEcPreKeyCount(count); - case KEM -> builder.setPniKemPreKeyCount(count); - } - } - } + final CompletableFuture pniEcKeyCountFuture = + keysManager.getEcCount(account.getIdentifier(IdentityType.PNI), authenticatedDevice.deviceId()); - return builder; - }) - .map(GetPreKeyCountResponse.Builder::build); + final CompletableFuture aciKemKeyCountFuture = + keysManager.getPqCount(account.getIdentifier(IdentityType.ACI), authenticatedDevice.deviceId()); + + final CompletableFuture pniKemKeyCountFuture = + keysManager.getPqCount(account.getIdentifier(IdentityType.PNI), authenticatedDevice.deviceId()); + + CompletableFuture.allOf(aciEcKeyCountFuture, pniEcKeyCountFuture, aciKemKeyCountFuture, pniKemKeyCountFuture).join(); + + return GetPreKeyCountResponse.newBuilder() + .setAciEcPreKeyCount(aciEcKeyCountFuture.resultNow()) + .setPniEcPreKeyCount(pniEcKeyCountFuture.resultNow()) + .setAciKemPreKeyCount(aciKemKeyCountFuture.resultNow()) + .setPniKemPreKeyCount(pniKemKeyCountFuture.resultNow()) + .build(); } @Override - public Mono getPreKeys(final GetPreKeysRequest request) { + public GetPreKeysResponse getPreKeys(final GetPreKeysRequest request) throws RateLimitExceededException { final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); final ServiceIdentifier targetIdentifier = @@ -126,107 +105,115 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { targetIdentifier.uuid() + "." + deviceId; - return rateLimiters.getPreKeysLimiter().validateReactive(rateLimitKey) - .then(Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier))) - .flatMap(Mono::justOrEmpty) + rateLimiters.getPreKeysLimiter().validate(rateLimitKey); + + return accountsManager.getByServiceIdentifier(targetIdentifier) .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier, deviceId, keysManager)) - .map(bundles -> GetPreKeysResponse.newBuilder() - .setPreKeys(bundles) + .map(accountPreKeyBundles -> GetPreKeysResponse.newBuilder() + .setPreKeys(accountPreKeyBundles) .build()) - .switchIfEmpty(Mono.fromSupplier(() -> GetPreKeysResponse.newBuilder() + .orElseGet(() -> GetPreKeysResponse.newBuilder() .setTargetNotFound(NotFound.getDefaultInstance()) - .build())); + .build()); } @Override - public Mono setOneTimeEcPreKeys(final SetOneTimeEcPreKeysRequest request) { + public SetPreKeyResponse setOneTimeEcPreKeys(final SetOneTimeEcPreKeysRequest request) { if (request.getPreKeysList().isEmpty()) { throw GrpcExceptions.fieldViolation("pre_keys", "pre_keys must be non-empty"); } - return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) - .flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(), - request.getPreKeysList(), - IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType()), - (requestPreKey, ignored) -> checkEcPreKey(requestPreKey), - (identifier, preKeys) -> keysManager.storeEcOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys))); + + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + + storeOneTimePreKeys(authenticatedDevice.accountIdentifier(), + request.getPreKeysList(), + IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType()), + (requestPreKey, _) -> checkEcPreKey(requestPreKey), + (identifier, preKeys) -> keysManager.storeEcOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys)); + + return SetPreKeyResponse.getDefaultInstance(); } @Override - public Mono setOneTimeKemSignedPreKeys(final SetOneTimeKemSignedPreKeysRequest request) { + public SetPreKeyResponse setOneTimeKemSignedPreKeys(final SetOneTimeKemSignedPreKeysRequest request) { if (request.getPreKeysList().isEmpty()) { throw GrpcExceptions.fieldViolation("pre_keys", "pre_keys must be non-empty"); } - return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) - .flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(), - request.getPreKeysList(), - IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType()), - KeysGrpcService::checkKemSignedPreKey, - (identifier, preKeys) -> keysManager.storeKemOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys))); + + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + + storeOneTimePreKeys(authenticatedDevice.accountIdentifier(), + request.getPreKeysList(), + IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType()), + KeysGrpcService::checkKemSignedPreKey, + (identifier, preKeys) -> keysManager.storeKemOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys)); + + return SetPreKeyResponse.getDefaultInstance(); } - private Mono storeOneTimePreKeys(final UUID authenticatedAccountUuid, + private void storeOneTimePreKeys(final UUID authenticatedAccountUuid, final List requestPreKeys, final IdentityType identityType, final BiFunction extractPreKeyFunction, final BiFunction, CompletableFuture> storeKeysFunction) { - return getAuthenticatedAccount(authenticatedAccountUuid) - .map(account -> { - final List preKeys = requestPreKeys.stream() - .map(requestPreKey -> extractPreKeyFunction.apply(requestPreKey, account.getIdentityKey(identityType))) - .toList(); + final Account account = getAuthenticatedAccount(authenticatedAccountUuid); - return Tuples.of(account.getIdentifier(identityType), preKeys); - }) - .flatMap(identifierAndPreKeys -> Mono.fromFuture(() -> storeKeysFunction.apply(identifierAndPreKeys.getT1(), identifierAndPreKeys.getT2()))) - .thenReturn(SetPreKeyResponse.newBuilder().build()); + final List preKeys = requestPreKeys.stream() + .map(requestPreKey -> extractPreKeyFunction.apply(requestPreKey, account.getIdentityKey(identityType))) + .toList(); + + storeKeysFunction.apply(account.getIdentifier(identityType), preKeys).join(); } @Override - public Mono setEcSignedPreKey(final SetEcSignedPreKeyRequest request) { - return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) - .flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(), - request.getIdentityType(), - request.getSignedPreKey(), - KeysGrpcService::checkEcSignedPreKey, - (account, signedPreKey) -> { - final IdentityType identityType = IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType()); - final UUID identifier = account.getIdentifier(identityType); + public SetPreKeyResponse setEcSignedPreKey(final SetEcSignedPreKeyRequest request) { + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); - return Mono.fromFuture(() -> keysManager.storeEcSignedPreKeys(identifier, authenticatedDevice.deviceId(), signedPreKey)); - })); + storeRepeatedUseKey(authenticatedDevice.accountIdentifier(), + request.getIdentityType(), + request.getSignedPreKey(), + KeysGrpcService::checkEcSignedPreKey, + (account, signedPreKey) -> { + final IdentityType identityType = IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType()); + final UUID identifier = account.getIdentifier(identityType); + + return keysManager.storeEcSignedPreKeys(identifier, authenticatedDevice.deviceId(), signedPreKey); + }); + + return SetPreKeyResponse.getDefaultInstance(); } @Override - public Mono setKemLastResortPreKey(final SetKemLastResortPreKeyRequest request) { - return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) - .flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(), - request.getIdentityType(), - request.getSignedPreKey(), - KeysGrpcService::checkKemSignedPreKey, - (account, lastResortKey) -> { - final UUID identifier = - account.getIdentifier(IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType())); + public SetPreKeyResponse setKemLastResortPreKey(final SetKemLastResortPreKeyRequest request) { + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); - return Mono.fromFuture(() -> keysManager.storePqLastResort(identifier, authenticatedDevice.deviceId(), lastResortKey)); - })); + storeRepeatedUseKey(authenticatedDevice.accountIdentifier(), + request.getIdentityType(), + request.getSignedPreKey(), + KeysGrpcService::checkKemSignedPreKey, + (account, lastResortKey) -> { + final UUID identifier = + account.getIdentifier(IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType())); + + return keysManager.storePqLastResort(identifier, authenticatedDevice.deviceId(), lastResortKey); + }); + + return SetPreKeyResponse.getDefaultInstance(); } - private Mono storeRepeatedUseKey(final UUID authenticatedAccountUuid, + private void storeRepeatedUseKey(final UUID authenticatedAccountUuid, final org.signal.chat.common.IdentityType identityType, final R storeKeyRequest, final BiFunction extractKeyFunction, - final BiFunction> storeKeyFunction) { + final BiFunction> storeKeyFunction) { - return getAuthenticatedAccount(authenticatedAccountUuid) - .map(account -> { - final IdentityKey identityKey = account.getIdentityKey(IdentityTypeUtil.fromGrpcIdentityType(identityType)); - final K key = extractKeyFunction.apply(storeKeyRequest, identityKey); + final Account account = getAuthenticatedAccount(authenticatedAccountUuid); - return Tuples.of(account, key); - }) - .flatMap(accountAndKey -> storeKeyFunction.apply(accountAndKey.getT1(), accountAndKey.getT2())) - .thenReturn(SetPreKeyResponse.newBuilder().build()); + final IdentityKey identityKey = account.getIdentityKey(IdentityTypeUtil.fromGrpcIdentityType(identityType)); + final K key = extractKeyFunction.apply(storeKeyRequest, identityKey); + + storeKeyFunction.apply(account, key).join(); } private static ECPreKey checkEcPreKey(final EcPreKey preKey) { @@ -269,8 +256,8 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { } } - private Mono getAuthenticatedAccount(final UUID authenticatedAccountId) { - return Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(authenticatedAccountId)) - .map(maybeAccount -> maybeAccount.orElseThrow(() -> GrpcExceptions.invalidCredentials("invalid credentials"))); + private Account getAuthenticatedAccount(final UUID authenticatedAccountId) { + return accountsManager.getByAccountIdentifier(authenticatedAccountId) + .orElseThrow(() -> GrpcExceptions.invalidCredentials("invalid credentials")); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java index d3a6af8ce..0de282635 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -98,8 +98,8 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest { - final Account account = invocation.getArgument(0); - final byte deviceId = invocation.getArgument(1); - final Consumer deviceUpdater = invocation.getArgument(2); - - account.getDevice(deviceId).ifPresent(deviceUpdater); - - return CompletableFuture.completedFuture(account); - }); - when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); final ECKeyPair identityKeyPair = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) { @@ -476,8 +466,8 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest devicePreKeysMap = new HashMap<>(); @@ -573,8 +563,8 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder() .setTargetIdentifier(ServiceIdentifier.newBuilder()