Add a gRPC service for working with pre-keys

This commit is contained in:
Jon Chambers
2023-07-20 11:10:26 -04:00
committed by GitHub
parent 0188d314ce
commit 5627209fdd
24 changed files with 2112 additions and 23 deletions

View File

@@ -0,0 +1,21 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status;
public enum IdentityType {
ACI,
PNI;
public static IdentityType fromGrpcIdentityType(final org.signal.chat.common.IdentityType grpcIdentityType) {
return switch (grpcIdentityType) {
case IDENTITY_TYPE_ACI -> ACI;
case IDENTITY_TYPE_PNI -> PNI;
case IDENTITY_TYPE_UNSPECIFIED, UNRECOGNIZED -> throw Status.INVALID_ARGUMENT.asRuntimeException();
};
}
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status;
import org.signal.chat.keys.GetPreKeysAnonymousRequest;
import org.signal.chat.keys.GetPreKeysResponse;
import org.signal.chat.keys.ReactorKeysAnonymousGrpc;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import reactor.core.publisher.Mono;
public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnonymousImplBase {
private final AccountsManager accountsManager;
private final KeysManager keysManager;
public KeysAnonymousGrpcService(final AccountsManager accountsManager, final KeysManager keysManager) {
this.accountsManager = accountsManager;
this.keysManager = keysManager;
}
@Override
public Mono<GetPreKeysResponse> getPreKeys(final GetPreKeysAnonymousRequest request) {
return KeysGrpcHelper.findAccount(request.getTargetIdentifier(), accountsManager)
.switchIfEmpty(Mono.error(Status.UNAUTHENTICATED.asException()))
.flatMap(targetAccount -> {
final IdentityType identityType =
IdentityType.fromGrpcIdentityType(request.getTargetIdentifier().getIdentityType());
return UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())
? KeysGrpcHelper.getPreKeys(targetAccount, identityType, request.getDeviceId(), keysManager)
: Mono.error(Status.UNAUTHENTICATED.asException());
});
}
}

View File

@@ -0,0 +1,107 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.util.UUID;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey;
import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.keys.GetPreKeysResponse;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
class KeysGrpcHelper {
@VisibleForTesting
static final long ALL_DEVICES = 0;
static Mono<Account> findAccount(final ServiceIdentifier targetIdentifier, final AccountsManager accountsManager) {
return Mono.just(IdentityType.fromGrpcIdentityType(targetIdentifier.getIdentityType()))
.flatMap(identityType -> {
final UUID uuid = UUIDUtil.fromByteString(targetIdentifier.getUuid());
return Mono.fromFuture(switch (identityType) {
case ACI -> accountsManager.getByAccountIdentifierAsync(uuid);
case PNI -> accountsManager.getByPhoneNumberIdentifierAsync(uuid);
});
})
.flatMap(Mono::justOrEmpty)
.onErrorMap(IllegalArgumentException.class, throwable -> Status.INVALID_ARGUMENT.asException());
}
static Tuple2<UUID, IdentityKey> getIdentifierAndIdentityKey(final Account account, final IdentityType identityType) {
final UUID identifier = switch (identityType) {
case ACI -> account.getUuid();
case PNI -> account.getPhoneNumberIdentifier();
};
final IdentityKey identityKey = switch (identityType) {
case ACI -> account.getIdentityKey();
case PNI -> account.getPhoneNumberIdentityKey();
};
return Tuples.of(identifier, identityKey);
}
static Mono<GetPreKeysResponse> getPreKeys(final Account targetAccount, final IdentityType identityType, final long targetDeviceId, final KeysManager keysManager) {
final Tuple2<UUID, IdentityKey> identifierAndIdentityKey = getIdentifierAndIdentityKey(targetAccount, identityType);
final Flux<Device> devices = targetDeviceId == ALL_DEVICES
? Flux.fromIterable(targetAccount.getDevices())
: Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId)));
return devices
.filter(Device::isEnabled)
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
.flatMap(device -> Mono.zip(Mono.fromFuture(keysManager.takeEC(identifierAndIdentityKey.getT1(), device.getId())),
Mono.fromFuture(keysManager.takePQ(identifierAndIdentityKey.getT1(), device.getId())))
.map(oneTimePreKeys -> {
final ECSignedPreKey ecSignedPreKey = switch (identityType) {
case ACI -> device.getSignedPreKey();
case PNI -> device.getPhoneNumberIdentitySignedPreKey();
};
final GetPreKeysResponse.PreKeyBundle.Builder preKeyBundleBuilder = GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKey.keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
.build());
oneTimePreKeys.getT1().ifPresent(ecPreKey -> preKeyBundleBuilder.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecPreKey.keyId())
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
.build()));
oneTimePreKeys.getT2().ifPresent(kemSignedPreKey -> preKeyBundleBuilder.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemSignedPreKey.keyId())
.setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
.build()));
return Tuples.of(device.getId(), preKeyBundleBuilder.build());
}))
.collectMap(Tuple2::getT1, Tuple2::getT2)
.map(preKeyBundles -> GetPreKeysResponse.newBuilder()
.setIdentityKey(ByteString.copyFrom(identifierAndIdentityKey.getT2().serialize()))
.putAllPreKeys(preKeyBundles)
.build());
}
}

View File

@@ -0,0 +1,307 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.whispersystems.textsecuregcm.grpc.IdentityType.ACI;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey;
import org.signal.chat.keys.GetPreKeyCountRequest;
import org.signal.chat.keys.GetPreKeyCountResponse;
import org.signal.chat.keys.GetPreKeysRequest;
import org.signal.chat.keys.GetPreKeysResponse;
import org.signal.chat.keys.ReactorKeysGrpc;
import org.signal.chat.keys.SetEcSignedPreKeyRequest;
import org.signal.chat.keys.SetKemLastResortPreKeyRequest;
import org.signal.chat.keys.SetOneTimeEcPreKeysRequest;
import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
import org.signal.chat.keys.SetPreKeyResponse;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
private final AccountsManager accountsManager;
private final KeysManager keysManager;
private final RateLimiters rateLimiters;
private static final StatusRuntimeException INVALID_PUBLIC_KEY_EXCEPTION = Status.fromCode(Status.Code.INVALID_ARGUMENT)
.withDescription("Invalid public key")
.asRuntimeException();
private static final StatusRuntimeException INVALID_SIGNATURE_EXCEPTION = Status.fromCode(Status.Code.INVALID_ARGUMENT)
.withDescription("Invalid signature")
.asRuntimeException();
private enum PreKeyType {
EC,
KEM
}
public KeysGrpcService(final AccountsManager accountsManager,
final KeysManager keysManager,
final RateLimiters rateLimiters) {
this.accountsManager = accountsManager;
this.keysManager = keysManager;
this.rateLimiters = rateLimiters;
}
@Override
protected Throwable onErrorMap(final Throwable throwable) {
return RateLimitUtil.mapRateLimitExceededException(throwable);
}
@Override
public Mono<GetPreKeyCountResponse> getPreKeyCount(final GetPreKeyCountRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
.flatMap(authenticatedDevice -> Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier()))
.map(maybeAccount -> maybeAccount
.map(account -> Tuples.of(account, authenticatedDevice.deviceId()))
.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException)))
.flatMapMany(accountAndDeviceId -> Flux.just(
Tuples.of(ACI, accountAndDeviceId.getT1().getUuid(), accountAndDeviceId.getT2()),
Tuples.of(IdentityType.PNI, accountAndDeviceId.getT1().getPhoneNumberIdentifier(), accountAndDeviceId.getT2())
))
.flatMap(identityTypeUuidAndDeviceId -> Flux.merge(
Mono.fromFuture(keysManager.getEcCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3()))
.map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.EC, ecKeyCount)),
Mono.fromFuture(keysManager.getPqCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3()))
.map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.KEM, ecKeyCount))
))
.reduce(GetPreKeyCountResponse.newBuilder(), (builder, tuple) -> {
final IdentityType identityType = tuple.getT1();
final PreKeyType preKeyType = tuple.getT2();
final int count = tuple.getT3();
switch (identityType) {
case ACI -> {
switch (preKeyType) {
case EC -> builder.setAciEcPreKeyCount(count);
case KEM -> builder.setAciKemPreKeyCount(count);
}
}
case PNI -> {
switch (preKeyType) {
case EC -> builder.setPniEcPreKeyCount(count);
case KEM -> builder.setPniKemPreKeyCount(count);
}
}
}
return builder;
})
.map(GetPreKeyCountResponse.Builder::build);
}
@Override
public Mono<GetPreKeysResponse> getPreKeys(final GetPreKeysRequest request) {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
final String rateLimitKey;
{
final UUID targetUuid;
try {
targetUuid = UUIDUtil.fromByteString(request.getTargetIdentifier().getUuid());
} catch (final IllegalArgumentException e) {
throw Status.INVALID_ARGUMENT.asRuntimeException();
}
rateLimitKey = authenticatedDevice.accountIdentifier() + "." +
authenticatedDevice.deviceId() + "__" +
targetUuid + "." +
request.getDeviceId();
}
return rateLimiters.getPreKeysLimiter().validateReactive(rateLimitKey)
.then(KeysGrpcHelper.findAccount(request.getTargetIdentifier(), accountsManager))
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
.flatMap(targetAccount -> {
final IdentityType identityType =
IdentityType.fromGrpcIdentityType(request.getTargetIdentifier().getIdentityType());
return KeysGrpcHelper.getPreKeys(targetAccount, identityType, request.getDeviceId(), keysManager);
});
}
@Override
public Mono<SetPreKeyResponse> setOneTimeEcPreKeys(final SetOneTimeEcPreKeysRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
.flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(),
request.getPreKeysList(),
IdentityType.fromGrpcIdentityType(request.getIdentityType()),
(requestPreKey, ignored) -> checkEcPreKey(requestPreKey),
(identifier, preKeys) -> keysManager.storeEcOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys)));
}
@Override
public Mono<SetPreKeyResponse> setOneTimeKemSignedPreKeys(final SetOneTimeKemSignedPreKeysRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
.flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(),
request.getPreKeysList(),
IdentityType.fromGrpcIdentityType(request.getIdentityType()),
KeysGrpcService::checkKemSignedPreKey,
(identifier, preKeys) -> keysManager.storeKemOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys)));
}
private <K, R> Mono<SetPreKeyResponse> storeOneTimePreKeys(final UUID authenticatedAccountUuid,
final List<R> requestPreKeys,
final IdentityType identityType,
final BiFunction<R, IdentityKey, K> extractPreKeyFunction,
final BiFunction<UUID, List<K>, CompletableFuture<Void>> storeKeysFunction) {
return Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedAccountUuid))
.map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException))
.map(account -> {
final Tuple2<UUID, IdentityKey> identifierAndIdentityKey =
KeysGrpcHelper.getIdentifierAndIdentityKey(account, identityType);
final List<K> preKeys = requestPreKeys.stream()
.map(requestPreKey -> extractPreKeyFunction.apply(requestPreKey, identifierAndIdentityKey.getT2()))
.toList();
if (preKeys.isEmpty()) {
throw Status.INVALID_ARGUMENT.asRuntimeException();
}
return Tuples.of(identifierAndIdentityKey.getT1(), preKeys);
})
.flatMap(identifierAndPreKeys -> Mono.fromFuture(storeKeysFunction.apply(identifierAndPreKeys.getT1(), identifierAndPreKeys.getT2())))
.thenReturn(SetPreKeyResponse.newBuilder().build());
}
@Override
public Mono<SetPreKeyResponse> setEcSignedPreKey(final SetEcSignedPreKeyRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
.flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(),
request.getIdentityType(),
request.getSignedPreKey(),
KeysGrpcService::checkEcSignedPreKey,
(account, signedPreKey) -> {
final Consumer<Device> deviceUpdater = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) {
case ACI -> device -> device.setSignedPreKey(signedPreKey);
case PNI -> device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
};
final UUID identifier = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) {
case ACI -> account.getUuid();
case PNI -> account.getPhoneNumberIdentifier();
};
return Flux.merge(
Mono.fromFuture(keysManager.storeEcSignedPreKeys(identifier, Map.of(authenticatedDevice.deviceId(), signedPreKey))),
Mono.fromFuture(accountsManager.updateDeviceAsync(account, authenticatedDevice.deviceId(), deviceUpdater)))
.then();
}));
}
@Override
public Mono<SetPreKeyResponse> setKemLastResortPreKey(final SetKemLastResortPreKeyRequest request) {
return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
.flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(),
request.getIdentityType(),
request.getSignedPreKey(),
KeysGrpcService::checkKemSignedPreKey,
(account, lastResortKey) -> {
final UUID identifier = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) {
case ACI -> account.getUuid();
case PNI -> account.getPhoneNumberIdentifier();
};
return Mono.fromFuture(keysManager.storePqLastResort(identifier, Map.of(authenticatedDevice.deviceId(), lastResortKey)));
}));
}
private <K, R> Mono<SetPreKeyResponse> storeRepeatedUseKey(final UUID authenticatedAccountUuid,
final org.signal.chat.common.IdentityType identityType,
final R storeKeyRequest,
final BiFunction<R, IdentityKey, K> extractKeyFunction,
final BiFunction<Account, K, Mono<?>> storeKeyFunction) {
return Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedAccountUuid))
.map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException))
.map(account -> {
final IdentityKey identityKey = switch (IdentityType.fromGrpcIdentityType(identityType)) {
case ACI -> account.getIdentityKey();
case PNI -> account.getPhoneNumberIdentityKey();
};
final K key = extractKeyFunction.apply(storeKeyRequest, identityKey);
return Tuples.of(account, key);
})
.flatMap(accountAndKey -> storeKeyFunction.apply(accountAndKey.getT1(), accountAndKey.getT2()))
.thenReturn(SetPreKeyResponse.newBuilder().build());
}
private static ECPreKey checkEcPreKey(final EcPreKey preKey) {
try {
return new ECPreKey(preKey.getKeyId(), new ECPublicKey(preKey.getPublicKey().toByteArray()));
} catch (final InvalidKeyException e) {
throw INVALID_PUBLIC_KEY_EXCEPTION;
}
}
private static ECSignedPreKey checkEcSignedPreKey(final EcSignedPreKey preKey, final IdentityKey identityKey) {
try {
final ECSignedPreKey ecSignedPreKey = new ECSignedPreKey(preKey.getKeyId(),
new ECPublicKey(preKey.getPublicKey().toByteArray()),
preKey.getSignature().toByteArray());
if (ecSignedPreKey.signatureValid(identityKey)) {
return ecSignedPreKey;
} else {
throw INVALID_SIGNATURE_EXCEPTION;
}
} catch (final InvalidKeyException e) {
throw INVALID_PUBLIC_KEY_EXCEPTION;
}
}
private static KEMSignedPreKey checkKemSignedPreKey(final KemSignedPreKey preKey, final IdentityKey identityKey) {
try {
final KEMSignedPreKey kemSignedPreKey = new KEMSignedPreKey(preKey.getKeyId(),
new KEMPublicKey(preKey.getPublicKey().toByteArray()),
preKey.getSignature().toByteArray());
if (kemSignedPreKey.signatureValid(identityKey)) {
return kemSignedPreKey;
} else {
throw INVALID_SIGNATURE_EXCEPTION;
}
} catch (final InvalidKeyException e) {
throw INVALID_PUBLIC_KEY_EXCEPTION;
}
}
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusException;
import java.time.Duration;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public class RateLimitUtil {
public static final Metadata.Key<Duration> RETRY_AFTER_DURATION_KEY =
Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
@Override
public String toAsciiString(final Duration value) {
return value.toString();
}
@Override
public Duration parseAsciiString(final String serialized) {
return Duration.parse(serialized);
}
});
public static Throwable mapRateLimitExceededException(final Throwable throwable) {
if (throwable instanceof RateLimitExceededException rateLimitExceededException) {
@Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration()
.map(duration -> {
final Metadata metadata = new Metadata();
metadata.put(RETRY_AFTER_DURATION_KEY, duration);
return metadata;
}).orElse(null);
return new StatusException(Status.RESOURCE_EXHAUSTED, trailers);
}
return throwable;
}
}