mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-22 02:58:02 +01:00
Add a gRPC service for working with pre-keys
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.auth;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.security.SecureRandom;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
|
||||
class UnidentifiedAccessUtilTest {
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void checkUnidentifiedAccess(@Nullable final byte[] targetUak,
|
||||
final boolean unrestrictedUnidentifiedAccess,
|
||||
final byte[] presentedUak,
|
||||
final boolean expectAccessAllowed) {
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.ofNullable(targetUak));
|
||||
when(account.isUnrestrictedUnidentifiedAccess()).thenReturn(unrestrictedUnidentifiedAccess);
|
||||
|
||||
assertEquals(expectAccessAllowed, UnidentifiedAccessUtil.checkUnidentifiedAccess(account, presentedUak));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> checkUnidentifiedAccess() {
|
||||
final byte[] uak = new byte[16];
|
||||
new SecureRandom().nextBytes(uak);
|
||||
|
||||
final byte[] incorrectUak = new byte[uak.length + 1];
|
||||
|
||||
return Stream.of(
|
||||
Arguments.of(null, false, uak, false),
|
||||
Arguments.of(null, true, uak, true),
|
||||
Arguments.of(uak, false, incorrectUak, false),
|
||||
Arguments.of(uak, false, uak, true),
|
||||
Arguments.of(uak, true, incorrectUak, true),
|
||||
Arguments.of(uak, true, uak, true)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import java.util.UUID;
|
||||
import javax.annotation.Nullable;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
|
||||
public class MockAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
@Nullable
|
||||
private Pair<UUID, Long> authenticatedDevice;
|
||||
|
||||
public void setAuthenticatedDevice(final UUID accountIdentifier, final long deviceId) {
|
||||
authenticatedDevice = new Pair<>(accountIdentifier, deviceId);
|
||||
}
|
||||
|
||||
public void clearAuthenticatedDevice() {
|
||||
authenticatedDevice = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (authenticatedDevice != null) {
|
||||
final Context context = Context.current()
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedDevice.first())
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedDevice.second());
|
||||
|
||||
return Contexts.interceptCall(context, call, headers, next);
|
||||
}
|
||||
|
||||
return next.startCall(call, headers);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.BindableService;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import io.grpc.util.MutableHandlerRegistry;
|
||||
import org.junit.jupiter.api.extension.AfterEachCallback;
|
||||
import org.junit.jupiter.api.extension.BeforeEachCallback;
|
||||
import org.junit.jupiter.api.extension.ExtensionContext;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
// This is mostly a direct port of
|
||||
// https://github.com/grpc/grpc-java/blob/master/testing/src/main/java/io/grpc/testing/GrpcServerRule.java, but for
|
||||
// JUnit 5.
|
||||
public class GrpcServerExtension implements BeforeEachCallback, AfterEachCallback {
|
||||
|
||||
private ManagedChannel channel;
|
||||
private Server server;
|
||||
private String serverName;
|
||||
private MutableHandlerRegistry serviceRegistry;
|
||||
private boolean useDirectExecutor;
|
||||
|
||||
/**
|
||||
* Returns {@code this} configured to use a direct executor for the {@link ManagedChannel} and
|
||||
* {@link Server}. This can only be called at the rule instantiation.
|
||||
*/
|
||||
public final GrpcServerExtension directExecutor() {
|
||||
if (serverName != null) {
|
||||
throw new IllegalStateException("directExecutor() can only be called at the rule instantiation");
|
||||
}
|
||||
|
||||
useDirectExecutor = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a {@link ManagedChannel} connected to this service.
|
||||
*/
|
||||
public final ManagedChannel getChannel() {
|
||||
return channel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the underlying gRPC {@link Server} for this service.
|
||||
*/
|
||||
public final Server getServer() {
|
||||
return server;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the randomly generated server name for this service.
|
||||
*/
|
||||
public final String getServerName() {
|
||||
return serverName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the service registry for this service. The registry is used to add service instances
|
||||
* (e.g. {@link BindableService} or {@link ServerServiceDefinition} to the server.
|
||||
*/
|
||||
public final MutableHandlerRegistry getServiceRegistry() {
|
||||
return serviceRegistry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeEach(final ExtensionContext extensionContext) throws Exception {
|
||||
serverName = UUID.randomUUID().toString();
|
||||
serviceRegistry = new MutableHandlerRegistry();
|
||||
|
||||
final InProcessServerBuilder serverBuilder = InProcessServerBuilder.forName(serverName)
|
||||
.fallbackHandlerRegistry(serviceRegistry);
|
||||
|
||||
if (useDirectExecutor) {
|
||||
serverBuilder.directExecutor();
|
||||
}
|
||||
|
||||
server = serverBuilder.build().start();
|
||||
|
||||
final InProcessChannelBuilder channelBuilder = InProcessChannelBuilder.forName(serverName);
|
||||
|
||||
if (useDirectExecutor) {
|
||||
channelBuilder.directExecutor();
|
||||
}
|
||||
|
||||
channel = channelBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterEach(final ExtensionContext extensionContext) throws Exception {
|
||||
serverName = null;
|
||||
serviceRegistry = null;
|
||||
|
||||
channel.shutdown();
|
||||
server.shutdown();
|
||||
|
||||
try {
|
||||
channel.awaitTermination(1, TimeUnit.MINUTES);
|
||||
server.awaitTermination(1, TimeUnit.MINUTES);
|
||||
} catch (final InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException(e);
|
||||
} finally {
|
||||
channel.shutdownNow();
|
||||
channel = null;
|
||||
|
||||
server.shutdownNow();
|
||||
server = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.Collections;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.signal.chat.common.EcPreKey;
|
||||
import org.signal.chat.common.EcSignedPreKey;
|
||||
import org.signal.chat.common.IdentityType;
|
||||
import org.signal.chat.common.KemSignedPreKey;
|
||||
import org.signal.chat.common.ServiceIdentifier;
|
||||
import org.signal.chat.keys.GetPreKeysAnonymousRequest;
|
||||
import org.signal.chat.keys.GetPreKeysResponse;
|
||||
import org.signal.chat.keys.KeysAnonymousGrpc;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.KeysManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class KeysAnonymousGrpcServiceTest {
|
||||
|
||||
private AccountsManager accountsManager;
|
||||
private KeysManager keysManager;
|
||||
|
||||
private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub;
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
keysManager = mock(KeysManager.class);
|
||||
|
||||
final KeysAnonymousGrpcService keysGrpcService =
|
||||
new KeysAnonymousGrpcService(accountsManager, keysManager);
|
||||
|
||||
keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getPreKeys() {
|
||||
final Account targetAccount = mock(Account.class);
|
||||
final Device targetDevice = mock(Device.class);
|
||||
|
||||
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
|
||||
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
|
||||
final UUID identifier = UUID.randomUUID();
|
||||
|
||||
final byte[] unidentifiedAccessKey = new byte[16];
|
||||
new SecureRandom().nextBytes(unidentifiedAccessKey);
|
||||
|
||||
when(targetDevice.getId()).thenReturn(Device.MASTER_ID);
|
||||
when(targetDevice.isEnabled()).thenReturn(true);
|
||||
when(targetAccount.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(targetDevice));
|
||||
|
||||
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
|
||||
when(targetAccount.getUuid()).thenReturn(identifier);
|
||||
when(targetAccount.getIdentityKey()).thenReturn(identityKey);
|
||||
when(accountsManager.getByAccountIdentifierAsync(identifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
final ECPreKey ecPreKey = new ECPreKey(1, Curve.generateKeyPair().getPublicKey());
|
||||
final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(2, identityKeyPair);
|
||||
final KEMSignedPreKey kemSignedPreKey = KeysHelper.signedKEMPreKey(3, identityKeyPair);
|
||||
|
||||
when(keysManager.takeEC(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(ecPreKey)));
|
||||
when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
|
||||
when(targetDevice.getSignedPreKey()).thenReturn(ecSignedPreKey);
|
||||
|
||||
final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
.build())
|
||||
.setDeviceId(Device.MASTER_ID)
|
||||
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
|
||||
.build());
|
||||
|
||||
final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder()
|
||||
.setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
|
||||
.putPreKeys(Device.MASTER_ID, GetPreKeysResponse.PreKeyBundle.newBuilder()
|
||||
.setEcOneTimePreKey(EcPreKey.newBuilder()
|
||||
.setKeyId(ecPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
|
||||
.build())
|
||||
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(ecSignedPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
|
||||
.build())
|
||||
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(kemSignedPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
|
||||
.build())
|
||||
.build())
|
||||
.build();
|
||||
|
||||
assertEquals(expectedResponse, response);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getPreKeysIncorrectUnidentifiedAccessKey() {
|
||||
final Account targetAccount = mock(Account.class);
|
||||
|
||||
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
|
||||
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
|
||||
final UUID identifier = UUID.randomUUID();
|
||||
|
||||
final byte[] unidentifiedAccessKey = new byte[16];
|
||||
new SecureRandom().nextBytes(unidentifiedAccessKey);
|
||||
|
||||
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
|
||||
when(targetAccount.getUuid()).thenReturn(identifier);
|
||||
when(targetAccount.getIdentityKey()).thenReturn(identityKey);
|
||||
when(accountsManager.getByAccountIdentifierAsync(identifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException statusRuntimeException =
|
||||
assertThrows(StatusRuntimeException.class,
|
||||
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
.build())
|
||||
.setDeviceId(Device.MASTER_ID)
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void getPreKeysAccountNotFound() {
|
||||
when(accountsManager.getByAccountIdentifierAsync(any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.UNAUTHENTICATED, exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
|
||||
void getPreKeysDeviceNotFound(final long deviceId) {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
|
||||
final byte[] unidentifiedAccessKey = new byte[16];
|
||||
new SecureRandom().nextBytes(unidentifiedAccessKey);
|
||||
|
||||
final Account targetAccount = mock(Account.class);
|
||||
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
|
||||
when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
|
||||
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
|
||||
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
|
||||
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
|
||||
|
||||
when(accountsManager.getByAccountIdentifierAsync(accountIdentifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
|
||||
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(accountIdentifier))
|
||||
.build())
|
||||
.setDeviceId(deviceId)
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,678 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.ServerInterceptors;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Stream;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.signal.chat.common.EcPreKey;
|
||||
import org.signal.chat.common.EcSignedPreKey;
|
||||
import org.signal.chat.common.KemSignedPreKey;
|
||||
import org.signal.chat.common.ServiceIdentifier;
|
||||
import org.signal.chat.keys.GetPreKeyCountRequest;
|
||||
import org.signal.chat.keys.GetPreKeyCountResponse;
|
||||
import org.signal.chat.keys.GetPreKeysRequest;
|
||||
import org.signal.chat.keys.GetPreKeysResponse;
|
||||
import org.signal.chat.keys.KeysGrpc;
|
||||
import org.signal.chat.keys.SetEcSignedPreKeyRequest;
|
||||
import org.signal.chat.keys.SetKemLastResortPreKeyRequest;
|
||||
import org.signal.chat.keys.SetOneTimeEcPreKeysRequest;
|
||||
import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.KeysManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
class KeysGrpcServiceTest {
|
||||
|
||||
private AccountsManager accountsManager;
|
||||
private KeysManager keysManager;
|
||||
private RateLimiter preKeysRateLimiter;
|
||||
|
||||
private Device authenticatedDevice;
|
||||
|
||||
private KeysGrpc.KeysBlockingStub keysStub;
|
||||
|
||||
private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
|
||||
private static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
|
||||
private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
|
||||
|
||||
private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
|
||||
private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
|
||||
|
||||
@RegisterExtension
|
||||
static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
keysManager = mock(KeysManager.class);
|
||||
preKeysRateLimiter = mock(RateLimiter.class);
|
||||
|
||||
final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||
when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter);
|
||||
|
||||
when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
|
||||
|
||||
final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters);
|
||||
keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
|
||||
|
||||
authenticatedDevice = mock(Device.class);
|
||||
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
|
||||
|
||||
final Account authenticatedAccount = mock(Account.class);
|
||||
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
|
||||
when(authenticatedAccount.getPhoneNumberIdentifier()).thenReturn(AUTHENTICATED_PNI);
|
||||
when(authenticatedAccount.getIdentityKey()).thenReturn(new IdentityKey(ACI_IDENTITY_KEY_PAIR.getPublicKey()));
|
||||
when(authenticatedAccount.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
|
||||
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
|
||||
|
||||
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
|
||||
mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
|
||||
|
||||
GRPC_SERVER_EXTENSION.getServiceRegistry()
|
||||
.addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor));
|
||||
|
||||
when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount));
|
||||
when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount));
|
||||
|
||||
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
|
||||
when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
|
||||
}
|
||||
|
||||
@Test
|
||||
void getPreKeyCount() {
|
||||
when(keysManager.getEcCount(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(1));
|
||||
|
||||
when(keysManager.getPqCount(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(2));
|
||||
|
||||
when(keysManager.getEcCount(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(3));
|
||||
|
||||
when(keysManager.getPqCount(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(4));
|
||||
|
||||
assertEquals(GetPreKeyCountResponse.newBuilder()
|
||||
.setAciEcPreKeyCount(1)
|
||||
.setAciKemPreKeyCount(2)
|
||||
.setPniEcPreKeyCount(3)
|
||||
.setPniKemPreKeyCount(4)
|
||||
.build(),
|
||||
keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
|
||||
void setOneTimeEcPreKeys(final org.signal.chat.common.IdentityType identityType) {
|
||||
final List<ECPreKey> preKeys = new ArrayList<>();
|
||||
|
||||
for (int keyId = 0; keyId < 100; keyId++) {
|
||||
preKeys.add(new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()));
|
||||
}
|
||||
|
||||
when(keysManager.storeEcOneTimePreKeys(any(), anyLong(), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.addAllPreKeys(preKeys.stream()
|
||||
.map(preKey -> EcPreKey.newBuilder()
|
||||
.setKeyId(preKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey()))
|
||||
.build())
|
||||
.toList())
|
||||
.build());
|
||||
|
||||
final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) {
|
||||
case ACI -> AUTHENTICATED_ACI;
|
||||
case PNI -> AUTHENTICATED_PNI;
|
||||
};
|
||||
|
||||
verify(keysManager).storeEcOneTimePreKeys(expectedIdentifier, AUTHENTICATED_DEVICE_ID, preKeys);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) {
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setOneTimeEcPreKeysWithError() {
|
||||
return Stream.of(
|
||||
// Missing identity type
|
||||
Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder()
|
||||
.addPreKeys(EcPreKey.newBuilder()
|
||||
.setKeyId(1)
|
||||
.setPublicKey(ByteString.copyFrom(Curve.generateKeyPair().getPublicKey().serialize()))
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Invalid public key
|
||||
Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.addPreKeys(EcPreKey.newBuilder()
|
||||
.setKeyId(1)
|
||||
.setPublicKey(ByteString.empty())
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// No keys
|
||||
Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.build())
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
|
||||
void setOneTimeKemSignedPreKeys(final org.signal.chat.common.IdentityType identityType) {
|
||||
final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
|
||||
case ACI -> ACI_IDENTITY_KEY_PAIR;
|
||||
case PNI -> PNI_IDENTITY_KEY_PAIR;
|
||||
};
|
||||
|
||||
final List<KEMSignedPreKey> preKeys = new ArrayList<>();
|
||||
|
||||
for (int keyId = 0; keyId < 100; keyId++) {
|
||||
preKeys.add(KeysHelper.signedKEMPreKey(keyId, identityKeyPair));
|
||||
}
|
||||
|
||||
when(keysManager.storeKemOneTimePreKeys(any(), anyLong(), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setOneTimeKemSignedPreKeys(
|
||||
SetOneTimeKemSignedPreKeysRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.addAllPreKeys(preKeys.stream()
|
||||
.map(preKey -> KemSignedPreKey.newBuilder()
|
||||
.setKeyId(preKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(preKey.signature()))
|
||||
.build())
|
||||
.toList())
|
||||
.build());
|
||||
|
||||
final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) {
|
||||
case ACI -> AUTHENTICATED_ACI;
|
||||
case PNI -> AUTHENTICATED_PNI;
|
||||
};
|
||||
|
||||
verify(keysManager).storeKemOneTimePreKeys(expectedIdentifier, AUTHENTICATED_DEVICE_ID, preKeys);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) {
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setOneTimeKemSignedPreKeysWithError() {
|
||||
final KEMSignedPreKey signedPreKey = KeysHelper.signedKEMPreKey(1, ACI_IDENTITY_KEY_PAIR);
|
||||
|
||||
return Stream.of(
|
||||
// Missing identity type
|
||||
Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
|
||||
.addPreKeys(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(1)
|
||||
.setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(signedPreKey.signature()))
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Invalid public key
|
||||
Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.addPreKeys(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(1)
|
||||
.setPublicKey(ByteString.empty())
|
||||
.setSignature(ByteString.copyFrom(signedPreKey.signature()))
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Invalid signature
|
||||
Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.addPreKeys(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(1)
|
||||
.setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.empty())
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// No keys
|
||||
Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.build())
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
|
||||
void setSignedPreKey(final org.signal.chat.common.IdentityType identityType) {
|
||||
when(accountsManager.updateDeviceAsync(any(), anyLong(), any())).thenAnswer(invocation -> {
|
||||
final Account account = invocation.getArgument(0);
|
||||
final long deviceId = invocation.getArgument(1);
|
||||
final Consumer<Device> deviceUpdater = invocation.getArgument(2);
|
||||
|
||||
account.getDevice(deviceId).ifPresent(deviceUpdater);
|
||||
|
||||
return CompletableFuture.completedFuture(account);
|
||||
});
|
||||
|
||||
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
|
||||
case ACI -> ACI_IDENTITY_KEY_PAIR;
|
||||
case PNI -> PNI_IDENTITY_KEY_PAIR;
|
||||
};
|
||||
|
||||
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair);
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.setSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(signedPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(signedPreKey.signature()))
|
||||
.build())
|
||||
.build());
|
||||
|
||||
switch (identityType) {
|
||||
case IDENTITY_TYPE_ACI -> {
|
||||
verify(authenticatedDevice).setSignedPreKey(signedPreKey);
|
||||
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey));
|
||||
}
|
||||
|
||||
case IDENTITY_TYPE_PNI -> {
|
||||
verify(authenticatedDevice).setPhoneNumberIdentitySignedPreKey(signedPreKey);
|
||||
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) {
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setSignedPreKeyWithError() {
|
||||
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, ACI_IDENTITY_KEY_PAIR);
|
||||
|
||||
return Stream.of(
|
||||
// Missing identity type
|
||||
Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
|
||||
.setSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(signedPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(signedPreKey.signature()))
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Invalid public key
|
||||
Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(signedPreKey.keyId())
|
||||
.setPublicKey(ByteString.empty())
|
||||
.setSignature(ByteString.copyFrom(signedPreKey.signature()))
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Invalid signature
|
||||
Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(signedPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.empty())
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Missing key
|
||||
Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.build())
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
|
||||
void setLastResortPreKey(final org.signal.chat.common.IdentityType identityType) {
|
||||
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
|
||||
|
||||
final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
|
||||
case ACI -> ACI_IDENTITY_KEY_PAIR;
|
||||
case PNI -> PNI_IDENTITY_KEY_PAIR;
|
||||
};
|
||||
|
||||
final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair);
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.setSignedPreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(lastResortPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
|
||||
.build())
|
||||
.build());
|
||||
|
||||
final UUID expectedIdentifier = switch (identityType) {
|
||||
case IDENTITY_TYPE_ACI -> AUTHENTICATED_ACI;
|
||||
case IDENTITY_TYPE_PNI -> AUTHENTICATED_PNI;
|
||||
case IDENTITY_TYPE_UNSPECIFIED, UNRECOGNIZED -> throw new AssertionError("Bad identity type");
|
||||
};
|
||||
|
||||
verify(keysManager).storePqLastResort(expectedIdentifier, Map.of(AUTHENTICATED_DEVICE_ID, lastResortPreKey));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) {
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request));
|
||||
|
||||
assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
private static Stream<Arguments> setLastResortPreKeyWithError() {
|
||||
final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, ACI_IDENTITY_KEY_PAIR);
|
||||
|
||||
return Stream.of(
|
||||
// No identity type
|
||||
Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
|
||||
.setSignedPreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(lastResortPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Bad public key
|
||||
Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setSignedPreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(lastResortPreKey.keyId())
|
||||
.setPublicKey(ByteString.empty())
|
||||
.setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Bad signature
|
||||
Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setSignedPreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(lastResortPreKey.keyId())
|
||||
.setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
|
||||
.setSignature(ByteString.empty())
|
||||
.build())
|
||||
.build()),
|
||||
|
||||
// Missing key
|
||||
Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.build())
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
|
||||
void getPreKeys(final org.signal.chat.common.IdentityType identityType) {
|
||||
final Account targetAccount = mock(Account.class);
|
||||
|
||||
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
|
||||
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
|
||||
final UUID identifier = UUID.randomUUID();
|
||||
|
||||
if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) {
|
||||
when(targetAccount.getUuid()).thenReturn(identifier);
|
||||
when(targetAccount.getIdentityKey()).thenReturn(identityKey);
|
||||
when(accountsManager.getByAccountIdentifierAsync(identifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
} else {
|
||||
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(targetAccount.getPhoneNumberIdentifier()).thenReturn(identifier);
|
||||
when(targetAccount.getPhoneNumberIdentityKey()).thenReturn(identityKey);
|
||||
when(accountsManager.getByPhoneNumberIdentifierAsync(identifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
}
|
||||
|
||||
final Map<Long, ECPreKey> ecOneTimePreKeys = new HashMap<>();
|
||||
final Map<Long, KEMSignedPreKey> kemPreKeys = new HashMap<>();
|
||||
final Map<Long, ECSignedPreKey> ecSignedPreKeys = new HashMap<>();
|
||||
|
||||
final Map<Long, Device> devices = new HashMap<>();
|
||||
|
||||
for (final long deviceId : List.of(1, 2)) {
|
||||
ecOneTimePreKeys.put(deviceId, new ECPreKey(1, Curve.generateKeyPair().getPublicKey()));
|
||||
kemPreKeys.put(deviceId, KeysHelper.signedKEMPreKey(2, identityKeyPair));
|
||||
ecSignedPreKeys.put(deviceId, KeysHelper.signedECPreKey(3, identityKeyPair));
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.isEnabled()).thenReturn(true);
|
||||
|
||||
if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) {
|
||||
when(device.getSignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId));
|
||||
} else {
|
||||
when(device.getPhoneNumberIdentitySignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId));
|
||||
}
|
||||
|
||||
devices.put(deviceId, device);
|
||||
when(targetAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
}
|
||||
|
||||
when(targetAccount.getDevices()).thenReturn(new ArrayList<>(devices.values()));
|
||||
|
||||
ecOneTimePreKeys.forEach((deviceId, preKey) -> when(keysManager.takeEC(identifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
|
||||
|
||||
kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
|
||||
|
||||
{
|
||||
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
.build())
|
||||
.setDeviceId(1)
|
||||
.build());
|
||||
|
||||
final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder()
|
||||
.setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
|
||||
.putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
|
||||
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(ecSignedPreKeys.get(1L).keyId())
|
||||
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
|
||||
.build())
|
||||
.setEcOneTimePreKey(EcPreKey.newBuilder()
|
||||
.setKeyId(ecOneTimePreKeys.get(1L).keyId())
|
||||
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
|
||||
.build())
|
||||
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(kemPreKeys.get(1L).keyId())
|
||||
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
|
||||
.build())
|
||||
.build())
|
||||
.build();
|
||||
|
||||
assertEquals(expectedResponse, response);
|
||||
}
|
||||
|
||||
when(keysManager.takeEC(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
{
|
||||
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(identityType)
|
||||
.setUuid(UUIDUtil.toByteString(identifier))
|
||||
.build())
|
||||
.build());
|
||||
|
||||
final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder()
|
||||
.setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
|
||||
.putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
|
||||
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(ecSignedPreKeys.get(1L).keyId())
|
||||
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
|
||||
.build())
|
||||
.setEcOneTimePreKey(EcPreKey.newBuilder()
|
||||
.setKeyId(ecOneTimePreKeys.get(1L).keyId())
|
||||
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
|
||||
.build())
|
||||
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
|
||||
.setKeyId(kemPreKeys.get(1L).keyId())
|
||||
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
|
||||
.build())
|
||||
.build())
|
||||
.putPreKeys(2, GetPreKeysResponse.PreKeyBundle.newBuilder()
|
||||
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
|
||||
.setKeyId(ecSignedPreKeys.get(2L).keyId())
|
||||
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(2L).serializedPublicKey()))
|
||||
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(2L).signature()))
|
||||
.build())
|
||||
.build())
|
||||
.build();
|
||||
|
||||
assertEquals(expectedResponse, response);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void getPreKeysAccountNotFound() {
|
||||
when(accountsManager.getByAccountIdentifierAsync(any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
|
||||
void getPreKeysDeviceNotFound(final long deviceId) {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
|
||||
final Account targetAccount = mock(Account.class);
|
||||
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
|
||||
when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
|
||||
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
|
||||
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
|
||||
|
||||
when(accountsManager.getByAccountIdentifierAsync(accountIdentifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(accountIdentifier))
|
||||
.build())
|
||||
.setDeviceId(deviceId)
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
void getPreKeysRateLimited() {
|
||||
final Account targetAccount = mock(Account.class);
|
||||
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
|
||||
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
|
||||
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
|
||||
|
||||
when(accountsManager.getByAccountIdentifierAsync(any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
|
||||
|
||||
final Duration retryAfterDuration = Duration.ofMinutes(7);
|
||||
|
||||
when(preKeysRateLimiter.validateReactive(anyString()))
|
||||
.thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
|
||||
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
|
||||
assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
|
||||
.setTargetIdentifier(ServiceIdentifier.newBuilder()
|
||||
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
|
||||
.setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
|
||||
.build())
|
||||
.build()));
|
||||
|
||||
assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
|
||||
assertNotNull(exception.getTrailers());
|
||||
assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user