Update keys gRPC endpoint to use service identifiers

This commit is contained in:
Jon Chambers
2023-07-21 13:03:01 -04:00
committed by GitHub
parent dc1cb9093a
commit 9df923d916
11 changed files with 229 additions and 172 deletions

View File

@@ -27,7 +27,6 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.IdentityType;
import org.signal.chat.common.KemSignedPreKey;
import org.signal.chat.common.ServiceIdentifier;
import org.signal.chat.keys.GetPreKeysAnonymousRequest;
@@ -39,6 +38,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -86,9 +87,9 @@ class KeysAnonymousGrpcServiceTest {
when(targetAccount.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(targetDevice));
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
when(targetAccount.getUuid()).thenReturn(identifier);
when(targetAccount.getIdentityKey()).thenReturn(identityKey);
when(accountsManager.getByAccountIdentifierAsync(identifier))
when(targetAccount.getIdentifier(IdentityType.ACI)).thenReturn(identifier);
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey);
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final ECPreKey ecPreKey = new ECPreKey(1, Curve.generateKeyPair().getPublicKey());
@@ -97,11 +98,11 @@ class KeysAnonymousGrpcServiceTest {
when(keysManager.takeEC(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(ecPreKey)));
when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
when(targetDevice.getSignedPreKey()).thenReturn(ecSignedPreKey);
when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey);
final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(IdentityType.IDENTITY_TYPE_ACI)
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(identifier))
.build())
.setDeviceId(Device.MASTER_ID)
@@ -144,15 +145,15 @@ class KeysAnonymousGrpcServiceTest {
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
when(targetAccount.getUuid()).thenReturn(identifier);
when(targetAccount.getIdentityKey()).thenReturn(identityKey);
when(accountsManager.getByAccountIdentifierAsync(identifier))
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey);
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(identifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException statusRuntimeException =
assertThrows(StatusRuntimeException.class,
() -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(IdentityType.IDENTITY_TYPE_ACI)
.setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
.setUuid(UUIDUtil.toByteString(identifier))
.build())
.setDeviceId(Device.MASTER_ID)
@@ -163,7 +164,7 @@ class KeysAnonymousGrpcServiceTest {
@Test
void getPreKeysAccountNotFound() {
when(accountsManager.getByAccountIdentifierAsync(any()))
when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
@@ -188,12 +189,12 @@ class KeysAnonymousGrpcServiceTest {
final Account targetAccount = mock(Account.class);
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
when(accountsManager.getByAccountIdentifierAsync(accountIdentifier))
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =

View File

@@ -11,6 +11,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -59,6 +60,8 @@ import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -109,8 +112,10 @@ class KeysGrpcServiceTest {
final Account authenticatedAccount = mock(Account.class);
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getPhoneNumberIdentifier()).thenReturn(AUTHENTICATED_PNI);
when(authenticatedAccount.getIdentityKey()).thenReturn(new IdentityKey(ACI_IDENTITY_KEY_PAIR.getPublicKey()));
when(authenticatedAccount.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getIdentifier(IdentityType.PNI)).thenReturn(AUTHENTICATED_PNI);
when(authenticatedAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(ACI_IDENTITY_KEY_PAIR.getPublicKey()));
when(authenticatedAccount.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
@@ -172,7 +177,7 @@ class KeysGrpcServiceTest {
.toList())
.build());
final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) {
final UUID expectedIdentifier = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) {
case ACI -> AUTHENTICATED_ACI;
case PNI -> AUTHENTICATED_PNI;
};
@@ -218,7 +223,7 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void setOneTimeKemSignedPreKeys(final org.signal.chat.common.IdentityType identityType) {
final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
final ECKeyPair identityKeyPair = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) {
case ACI -> ACI_IDENTITY_KEY_PAIR;
case PNI -> PNI_IDENTITY_KEY_PAIR;
};
@@ -245,7 +250,7 @@ class KeysGrpcServiceTest {
.toList())
.build());
final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) {
final UUID expectedIdentifier = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) {
case ACI -> AUTHENTICATED_ACI;
case PNI -> AUTHENTICATED_PNI;
};
@@ -317,7 +322,7 @@ class KeysGrpcServiceTest {
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
final ECKeyPair identityKeyPair = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) {
case ACI -> ACI_IDENTITY_KEY_PAIR;
case PNI -> PNI_IDENTITY_KEY_PAIR;
};
@@ -401,7 +406,7 @@ class KeysGrpcServiceTest {
void setLastResortPreKey(final org.signal.chat.common.IdentityType identityType) {
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
final ECKeyPair identityKeyPair = switch (IdentityTypeUtil.fromGrpcIdentityType(identityType)) {
case ACI -> ACI_IDENTITY_KEY_PAIR;
case PNI -> PNI_IDENTITY_KEY_PAIR;
};
@@ -478,25 +483,20 @@ class KeysGrpcServiceTest {
@ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void getPreKeys(final org.signal.chat.common.IdentityType identityType) {
void getPreKeys(final org.signal.chat.common.IdentityType grpcIdentityType) {
final Account targetAccount = mock(Account.class);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final UUID identifier = UUID.randomUUID();
if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) {
when(targetAccount.getUuid()).thenReturn(identifier);
when(targetAccount.getIdentityKey()).thenReturn(identityKey);
when(accountsManager.getByAccountIdentifierAsync(identifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
} else {
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
when(targetAccount.getPhoneNumberIdentifier()).thenReturn(identifier);
when(targetAccount.getPhoneNumberIdentityKey()).thenReturn(identityKey);
when(accountsManager.getByPhoneNumberIdentifierAsync(identifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
}
final IdentityType identityType = IdentityTypeUtil.fromGrpcIdentityType(grpcIdentityType);
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
when(targetAccount.getIdentifier(identityType)).thenReturn(identifier);
when(targetAccount.getIdentityKey(identityType)).thenReturn(identityKey);
when(accountsManager.getByServiceIdentifierAsync(argThat(serviceIdentifier -> serviceIdentifier.uuid().equals(identifier))))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final Map<Long, ECPreKey> ecOneTimePreKeys = new HashMap<>();
final Map<Long, KEMSignedPreKey> kemPreKeys = new HashMap<>();
@@ -512,12 +512,7 @@ class KeysGrpcServiceTest {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) {
when(device.getSignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId));
} else {
when(device.getPhoneNumberIdentitySignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId));
}
when(device.getSignedPreKey(identityType)).thenReturn(ecSignedPreKeys.get(deviceId));
devices.put(deviceId, device);
when(targetAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
@@ -534,7 +529,7 @@ class KeysGrpcServiceTest {
{
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(identityType)
.setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier))
.build())
.setDeviceId(1)
@@ -569,7 +564,7 @@ class KeysGrpcServiceTest {
{
final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
.setTargetIdentifier(ServiceIdentifier.newBuilder()
.setIdentityType(identityType)
.setIdentityType(grpcIdentityType)
.setUuid(UUIDUtil.toByteString(identifier))
.build())
.build());
@@ -607,7 +602,7 @@ class KeysGrpcServiceTest {
@Test
void getPreKeysAccountNotFound() {
when(accountsManager.getByAccountIdentifierAsync(any()))
when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
@@ -628,11 +623,11 @@ class KeysGrpcServiceTest {
final Account targetAccount = mock(Account.class);
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(accountsManager.getByAccountIdentifierAsync(accountIdentifier))
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
@SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
@@ -651,11 +646,11 @@ class KeysGrpcServiceTest {
void getPreKeysRateLimited() {
final Account targetAccount = mock(Account.class);
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(accountsManager.getByAccountIdentifierAsync(any()))
when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final Duration retryAfterDuration = Duration.ofMinutes(7);

View File

@@ -218,7 +218,7 @@ class AccountsManagerTest {
when(commands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString());
when(commands.get(eq("Account3::" + aci))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
"{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}");
assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent());
assertTrue(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni)).isPresent());
@@ -226,6 +226,29 @@ class AccountsManagerTest {
assertFalse(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(aci)).isPresent());
}
@Test
void testGetByServiceIdentifierAsync() {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
when(asyncCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(aci.toString()));
when(asyncCommands.get(eq("Account3::" + aci))).thenReturn(MockRedisFuture.completedFuture(
"{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByAccountIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accounts.getByPhoneNumberIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
assertTrue(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(aci)).join().isPresent());
assertTrue(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(pni)).join().isPresent());
assertFalse(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(pni)).join().isPresent());
assertFalse(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(aci)).join().isPresent());
}
@Test
void testGetAccountByNumberInCache() {
UUID uuid = UUID.randomUUID();
@@ -315,7 +338,7 @@ class AccountsManagerTest {
}
@Test
void testGetByPniInCache() {
void testGetAccountByPniInCache() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
@@ -337,7 +360,7 @@ class AccountsManagerTest {
}
@Test
void testGetByPniInCacheAsync() {
void testGetAccountByPniInCacheAsync() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
@@ -363,7 +386,7 @@ class AccountsManagerTest {
}
@Test
void testGetByUsernameHashInCache() {
void testGetAccountByUsernameHashInCache() {
UUID uuid = UUID.randomUUID();
when(commands.get(eq("UAccountMap::" + BASE_64_URL_USERNAME_HASH_1))).thenReturn(uuid.toString());
when(commands.get(eq("Account3::" + uuid))).thenReturn(