Check for IdentityType.PNI in OptionalAccess#verify

This commit is contained in:
Chris Eager
2024-07-23 15:15:27 -05:00
committed by Chris Eager
parent 8afc0e6ab2
commit e4ffc932a9
7 changed files with 94 additions and 41 deletions

View File

@@ -16,11 +16,16 @@ import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
import javax.ws.rs.WebApplicationException;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -32,15 +37,17 @@ class OptionalAccessTest {
void verify(final Optional<Account> requestAccount,
final Optional<Anonymous> accessKey,
final Optional<Account> targetAccount,
final ServiceIdentifier targetIdentifier,
final String deviceSelector,
final OptionalInt expectedStatusCode) {
expectedStatusCode.ifPresentOrElse(statusCode -> {
final WebApplicationException webApplicationException = assertThrows(WebApplicationException.class,
() -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, deviceSelector));
() -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, targetIdentifier, deviceSelector));
assertEquals(statusCode, webApplicationException.getResponse().getStatus());
}, () -> assertDoesNotThrow(() -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, deviceSelector)));
}, () -> assertDoesNotThrow(() ->
OptionalAccess.verify(requestAccount, accessKey, targetAccount, targetIdentifier, deviceSelector)));
}
private static List<Arguments> verify() {
@@ -53,28 +60,39 @@ class OptionalAccessTest {
new Anonymous(Base64.getEncoder().encodeToString((unidentifiedAccessKey + "-incorrect").getBytes()));
final Account targetAccount = mock(Account.class);
final ServiceIdentifier targetAccountAciIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier targetAccountPniIdentifier = new PniServiceIdentifier(UUID.randomUUID());
when(targetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
when(targetAccount.getUnidentifiedAccessKey())
.thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8)));
when(targetAccount.isIdentifiedBy(targetAccountAciIdentifier)).thenReturn(true);
when(targetAccount.isIdentifiedBy(targetAccountPniIdentifier)).thenReturn(true);
final Account allowAllTargetAccount = mock(Account.class);
final ServiceIdentifier allowAllTargetAccountPniIdentifier = new PniServiceIdentifier(UUID.randomUUID());
when(allowAllTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
when(allowAllTargetAccount.isUnrestrictedUnidentifiedAccess()).thenReturn(true);
when(allowAllTargetAccount.isIdentifiedBy(allowAllTargetAccountPniIdentifier)).thenReturn(true);
final Account noUakTargetAccount = mock(Account.class);
final ServiceIdentifier noUakTargetAccountAciIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(noUakTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
when(noUakTargetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.empty());
when(noUakTargetAccount.isIdentifiedBy(noUakTargetAccountAciIdentifier)).thenReturn(true);
final Account inactiveTargetAccount = mock(Account.class);
final ServiceIdentifier inactiveTargetAccountAciIdentifier = new AciServiceIdentifier(UUID.randomUUID());
when(inactiveTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
when(inactiveTargetAccount.getUnidentifiedAccessKey())
.thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8)));
when(inactiveTargetAccount.isIdentifiedBy(inactiveTargetAccountAciIdentifier)).thenReturn(true);
return List.of(
// Unidentified caller; correct UAK
Arguments.of(Optional.empty(),
Optional.of(correctUakHeader),
Optional.of(targetAccount),
targetAccountAciIdentifier,
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.empty()),
@@ -82,6 +100,7 @@ class OptionalAccessTest {
Arguments.of(Optional.of(mock(Account.class)),
Optional.empty(),
Optional.of(targetAccount),
targetAccountAciIdentifier,
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.empty()),
@@ -89,6 +108,7 @@ class OptionalAccessTest {
Arguments.of(Optional.empty(),
Optional.empty(),
Optional.empty(),
new AciServiceIdentifier(UUID.randomUUID()),
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.of(401)),
@@ -96,6 +116,7 @@ class OptionalAccessTest {
Arguments.of(Optional.of(mock(Account.class)),
Optional.empty(),
Optional.empty(),
new AciServiceIdentifier(UUID.randomUUID()),
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.of(404)),
@@ -103,6 +124,7 @@ class OptionalAccessTest {
Arguments.of(Optional.empty(),
Optional.of(correctUakHeader),
Optional.of(targetAccount),
targetAccountAciIdentifier,
String.valueOf(Device.PRIMARY_ID + 1),
OptionalInt.of(401)),
@@ -110,6 +132,7 @@ class OptionalAccessTest {
Arguments.of(Optional.empty(),
Optional.of(incorrectUakHeader),
Optional.of(targetAccount),
targetAccountAciIdentifier,
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.of(401)),
@@ -117,13 +140,15 @@ class OptionalAccessTest {
Arguments.of(Optional.empty(),
Optional.of(correctUakHeader),
Optional.of(noUakTargetAccount),
noUakTargetAccountAciIdentifier,
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.of(401)),
// Unidentified caller; target account found, allows unrestricted unidentified access
// Unidentified caller; target account found, allows unrestricted unidentified access, so PNI target doesn't matter
Arguments.of(Optional.empty(),
Optional.of(incorrectUakHeader),
Optional.of(allowAllTargetAccount),
allowAllTargetAccountPniIdentifier,
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.empty()),
@@ -131,6 +156,7 @@ class OptionalAccessTest {
Arguments.of(Optional.empty(),
Optional.of(correctUakHeader),
Optional.of(inactiveTargetAccount),
inactiveTargetAccountAciIdentifier,
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.empty()),
@@ -138,8 +164,35 @@ class OptionalAccessTest {
Arguments.of(Optional.empty(),
Optional.of(correctUakHeader),
Optional.of(targetAccount),
targetAccountAciIdentifier,
"not a valid identifier",
OptionalInt.of(422))
OptionalInt.of(422)),
// Unidentified caller; target account found, but PNI identifier
Arguments.of(Optional.empty(),
Optional.of(correctUakHeader),
Optional.of(targetAccount),
targetAccountPniIdentifier,
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.of(401))
);
}
@Test
void testTargetIdentifierIllegalArgument() {
final String unidentifiedAccessKey = RandomStringUtils.randomAlphanumeric(16);
final Anonymous correctUakHeader =
new Anonymous(Base64.getEncoder().encodeToString(unidentifiedAccessKey.getBytes()));
final Account targetAccount = mock(Account.class);
when(targetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
when(targetAccount.getUnidentifiedAccessKey())
.thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8)));
assertThrows(IllegalArgumentException.class,
() -> OptionalAccess.verify(Optional.empty(), Optional.of(correctUakHeader), Optional.of(targetAccount),
new AciServiceIdentifier(UUID.randomUUID())));
}
}

View File

@@ -38,7 +38,6 @@ import java.util.OptionalInt;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.core.MediaType;
@@ -51,7 +50,6 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
@@ -227,7 +225,9 @@ class KeysControllerTest {
when(sampleDevice4.getId()).thenReturn(sampleDevice4Id);
when(existsAccount.getUuid()).thenReturn(EXISTS_UUID);
when(existsAccount.isIdentifiedBy(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(true);
when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI);
when(existsAccount.isIdentifiedBy(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(true);
when(existsAccount.getIdentifier(IdentityType.ACI)).thenReturn(EXISTS_UUID);
when(existsAccount.getIdentifier(IdentityType.PNI)).thenReturn(EXISTS_PNI);
when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice));

View File

@@ -202,6 +202,8 @@ class ProfileControllerTest {
when(profileAccount.getCurrentProfileVersion()).thenReturn(Optional.empty());
when(profileAccount.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH));
when(profileAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY));
when(profileAccount.isIdentifiedBy(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO)))).thenReturn(true);
when(profileAccount.isIdentifiedBy(eq(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO)))).thenReturn(true);
capabilitiesAccount = mock(Account.class);
@@ -1166,6 +1168,7 @@ class ProfileControllerTest {
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getCurrentProfileVersion()).thenReturn(Optional.of(version));
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY));
when(account.isIdentifiedBy(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(true);
final Instant expiration = Instant.now().plus(org.whispersystems.textsecuregcm.util.ProfileHelper.EXPIRING_PROFILE_KEY_CREDENTIAL_EXPIRATION)
.truncatedTo(ChronoUnit.DAYS);
@@ -1231,6 +1234,7 @@ class ProfileControllerTest {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY));
when(account.isIdentifiedBy(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(true);
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(account));
when(profilesManager.get(AuthHelper.VALID_UUID, version)).thenReturn(Optional.of(versionedProfile));