Add svr3 share-set store/retrieve

This commit is contained in:
ravi-signal
2024-05-17 10:45:18 -05:00
committed by GitHub
parent 1182d159aa
commit ce1c5be940
18 changed files with 493 additions and 92 deletions

View File

@@ -49,12 +49,15 @@ class RegistrationLockVerificationManagerTest {
private final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private final ExternalServiceCredentialsGenerator svr2CredentialsGenerator = mock(
ExternalServiceCredentialsGenerator.class);
private final ExternalServiceCredentialsGenerator svr3CredentialsGenerator = mock(
ExternalServiceCredentialsGenerator.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, clientPresenceManager, svr2CredentialsGenerator, registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);
accountsManager, clientPresenceManager, svr2CredentialsGenerator, svr3CredentialsGenerator,
registrationRecoveryPasswordsManager, pushNotificationManager, rateLimiters);
private final RateLimiter pinLimiter = mock(RateLimiter.class);

View File

@@ -15,10 +15,14 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.entities.AuthCheckResponseV2;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import javax.ws.rs.core.Response;
import java.util.Map;
import java.util.stream.Collectors;
@ExtendWith(DropwizardExtensionsSupport.class)
public class SecureValueRecovery2ControllerTest extends SecureValueRecoveryControllerBaseTest {
@@ -51,4 +55,16 @@ public class SecureValueRecovery2ControllerTest extends SecureValueRecoveryContr
protected SecureValueRecovery2ControllerTest() {
super("/v2", ACCOUNTS_MANAGER, CLOCK, RESOURCES, CREDENTIAL_GENERATOR);
}
@Override
Map<String, CheckStatus> parseCheckResponse(final Response response) {
final AuthCheckResponseV2 authCheckResponseV2 = response.readEntity(AuthCheckResponseV2.class);
return authCheckResponseV2.matches().entrySet().stream().collect(Collectors.toMap(
Map.Entry::getKey, e -> switch (e.getValue()) {
case MATCH -> CheckStatus.MATCH;
case INVALID -> CheckStatus.INVALID;
case NO_MATCH -> CheckStatus.NO_MATCH;
}
));
}
}

View File

@@ -6,20 +6,50 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.util.MockUtils.randomSecretBytes;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.util.Base64;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery3Configuration;
import org.whispersystems.textsecuregcm.entities.AuthCheckRequest;
import org.whispersystems.textsecuregcm.entities.AuthCheckResponseV3;
import org.whispersystems.textsecuregcm.entities.SetShareSetRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import static org.mockito.Mockito.mock;
import static org.whispersystems.textsecuregcm.util.MockUtils.randomSecretBytes;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ExtendWith(DropwizardExtensionsSupport.class)
public class SecureValueRecovery3ControllerTest extends SecureValueRecoveryControllerBaseTest {
@@ -44,6 +74,7 @@ public class SecureValueRecovery3ControllerTest extends SecureValueRecoveryContr
private static final ResourceExtension RESOURCES = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedAccount.class))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(CONTROLLER)
@@ -52,4 +83,100 @@ public class SecureValueRecovery3ControllerTest extends SecureValueRecoveryContr
protected SecureValueRecovery3ControllerTest() {
super("/v3", ACCOUNTS_MANAGER, CLOCK, RESOURCES, CREDENTIAL_GENERATOR);
}
@Override
Map<String, CheckStatus> parseCheckResponse(final Response response) {
final AuthCheckResponseV3 authCheckResponse = response.readEntity(AuthCheckResponseV3.class);
assertFalse(authCheckResponse.matches()
.values().stream()
.anyMatch(r -> r.status() == AuthCheckResponseV3.CredentialStatus.MATCH && r.shareSet() == null),
"SVR3 matches must contain a non-empty share-set");
return authCheckResponse.matches().entrySet().stream().collect(Collectors.toMap(
Map.Entry::getKey, e -> switch (e.getValue().status()) {
case MATCH -> CheckStatus.MATCH;
case INVALID -> CheckStatus.INVALID;
case NO_MATCH -> CheckStatus.NO_MATCH;
}
));
}
public static Stream<Arguments> checkShareSet() {
byte[] shareSet = TestRandomUtil.nextBytes(100);
return Stream.of(
Arguments.of(shareSet, AuthCheckResponseV3.Result.match(shareSet)),
Arguments.of(null, AuthCheckResponseV3.Result.match(null)));
}
@ParameterizedTest
@MethodSource
public void checkShareSet(@Nullable byte[] shareSet, AuthCheckResponseV3.Result expectedResult) {
final String e164 = "+18005550101";
final UUID uuid = UUID.randomUUID();
final String token = token(uuid, day(10));
CLOCK.setTimeMillis(day(11));
final Account a = mock(Account.class);
when(a.getUuid()).thenReturn(uuid);
when(a.getSvr3ShareSet()).thenReturn(shareSet);
when(ACCOUNTS_MANAGER.getByE164(e164)).thenReturn(Optional.of(a));
final AuthCheckRequest in = new AuthCheckRequest(e164, Collections.singletonList(token));
final Response response = RESOURCES.getJerseyTest()
.target("/v3/backup/auth/check")
.request()
.post(Entity.entity(in, MediaType.APPLICATION_JSON));
try (response) {
assertEquals(200, response.getStatus());
AuthCheckResponseV3 checkResponse = response.readEntity(AuthCheckResponseV3.class);
assertEquals(checkResponse.matches().size(), 1);
assertEquals(checkResponse.matches().get(token).status(), expectedResult.status());
assertArrayEquals(checkResponse.matches().get(token).shareSet(), expectedResult.shareSet());
}
}
@Test
public void setShareSet() {
final Account a = mock(Account.class);
when(ACCOUNTS_MANAGER.update(any(), any())).thenAnswer(invocation -> {
final Consumer<Account> updater = invocation.getArgument(1);
updater.accept(a);
return null;
});
byte[] shareSet = TestRandomUtil.nextBytes(SetShareSetRequest.SHARE_SET_SIZE);
final Response response = RESOURCES.getJerseyTest()
.target("/v3/backup/share-set")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new SetShareSetRequest(shareSet), MediaType.APPLICATION_JSON));
assertEquals(204, response.getStatus());
verify(a, times(1)).setSvr3ShareSet(eq(shareSet));
}
static Stream<Arguments> requestParsing() {
return Stream.of(
Arguments.of("", 422),
Arguments.of(null, 422),
Arguments.of("abc**", 400), // bad base64
Arguments.of(Base64.getEncoder().encodeToString(TestRandomUtil.nextBytes(SetShareSetRequest.SHARE_SET_SIZE - 1)), 422),
Arguments.of(Base64.getEncoder().encodeToString(TestRandomUtil.nextBytes(SetShareSetRequest.SHARE_SET_SIZE)), 204));
}
@ParameterizedTest
@MethodSource
public void requestParsing(String shareSet, int responseCode) {
final Response response = RESOURCES.getJerseyTest()
.target("/v3/backup/share-set")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity("""
{"shareSet": "%s"}
""".formatted(shareSet), MediaType.APPLICATION_JSON));
assertEquals(responseCode, response.getStatus());
}
}

View File

@@ -23,10 +23,10 @@ import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.entities.AuthCheckRequest;
import org.whispersystems.textsecuregcm.entities.AuthCheckResponse;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
abstract class SecureValueRecoveryControllerBaseTest {
@@ -64,20 +64,27 @@ abstract class SecureValueRecoveryControllerBaseTest {
this.clock = mutableClock;
}
enum CheckStatus {
MATCH,
NO_MATCH,
INVALID
}
abstract Map<String, CheckStatus> parseCheckResponse(Response response);
@Test
public void testOneMatch() throws Exception {
validate(Map.of(
token(USER_1, day(1)), AuthCheckResponse.Result.MATCH,
token(USER_2, day(1)), AuthCheckResponse.Result.NO_MATCH,
token(USER_3, day(1)), AuthCheckResponse.Result.NO_MATCH
token(USER_1, day(1)), CheckStatus.MATCH,
token(USER_2, day(1)), CheckStatus.NO_MATCH,
token(USER_3, day(1)), CheckStatus.NO_MATCH
), day(2));
}
@Test
public void testNoMatch() throws Exception {
validate(Map.of(
token(USER_2, day(1)), AuthCheckResponse.Result.NO_MATCH,
token(USER_3, day(1)), AuthCheckResponse.Result.NO_MATCH
token(USER_2, day(1)), CheckStatus.NO_MATCH,
token(USER_3, day(1)), CheckStatus.NO_MATCH
), day(2));
}
@@ -89,35 +96,35 @@ abstract class SecureValueRecoveryControllerBaseTest {
final String fakeToken = token(new ExternalServiceCredentials(user2Cred.username(), user3Cred.password()));
validate(Map.of(
token(user1Cred), AuthCheckResponse.Result.MATCH,
token(user2Cred), AuthCheckResponse.Result.NO_MATCH,
fakeToken, AuthCheckResponse.Result.INVALID
token(user1Cred), CheckStatus.MATCH,
token(user2Cred), CheckStatus.NO_MATCH,
fakeToken, CheckStatus.INVALID
), day(2));
}
@Test
public void testSomeExpired() throws Exception {
validate(Map.of(
token(USER_1, day(100)), AuthCheckResponse.Result.MATCH,
token(USER_2, day(100)), AuthCheckResponse.Result.NO_MATCH,
token(USER_3, day(10)), AuthCheckResponse.Result.INVALID,
token(USER_3, day(20)), AuthCheckResponse.Result.INVALID
token(USER_1, day(100)), CheckStatus.MATCH,
token(USER_2, day(100)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID,
token(USER_3, day(20)), CheckStatus.INVALID
), day(110));
}
@Test
public void testSomeHaveNewerVersions() throws Exception {
validate(Map.of(
token(USER_1, day(10)), AuthCheckResponse.Result.INVALID,
token(USER_1, day(20)), AuthCheckResponse.Result.MATCH,
token(USER_2, day(10)), AuthCheckResponse.Result.NO_MATCH,
token(USER_3, day(20)), AuthCheckResponse.Result.NO_MATCH,
token(USER_3, day(10)), AuthCheckResponse.Result.INVALID
token(USER_1, day(10)), CheckStatus.INVALID,
token(USER_1, day(20)), CheckStatus.MATCH,
token(USER_2, day(10)), CheckStatus.NO_MATCH,
token(USER_3, day(20)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID
), day(25));
}
private void validate(
final Map<String, AuthCheckResponse.Result> expected,
final Map<String, CheckStatus> expected,
final long nowMillis) throws Exception {
clock.setTimeMillis(nowMillis);
final AuthCheckRequest request = new AuthCheckRequest(E164_VALID, List.copyOf(expected.keySet()));
@@ -125,20 +132,20 @@ abstract class SecureValueRecoveryControllerBaseTest {
.request()
.post(Entity.entity(request, MediaType.APPLICATION_JSON));
try (response) {
final AuthCheckResponse res = response.readEntity(AuthCheckResponse.class);
assertEquals(200, response.getStatus());
assertEquals(expected, res.matches());
final Map<String, CheckStatus> res = parseCheckResponse(response);
assertEquals(expected, res);
}
}
@Test
public void testHttpResponseCodeSuccess() throws Exception {
final Map<String, AuthCheckResponse.Result> expected = Map.of(
token(USER_1, day(10)), AuthCheckResponse.Result.INVALID,
token(USER_1, day(20)), AuthCheckResponse.Result.MATCH,
token(USER_2, day(10)), AuthCheckResponse.Result.NO_MATCH,
token(USER_3, day(20)), AuthCheckResponse.Result.NO_MATCH,
token(USER_3, day(10)), AuthCheckResponse.Result.INVALID
final Map<String, CheckStatus> expected = Map.of(
token(USER_1, day(10)), CheckStatus.INVALID,
token(USER_1, day(20)), CheckStatus.MATCH,
token(USER_2, day(10)), CheckStatus.NO_MATCH,
token(USER_3, day(20)), CheckStatus.NO_MATCH,
token(USER_3, day(10)), CheckStatus.INVALID
);
clock.setTimeMillis(day(25));
@@ -151,9 +158,8 @@ abstract class SecureValueRecoveryControllerBaseTest {
.post(Entity.entity(in, MediaType.APPLICATION_JSON));
try (response) {
final AuthCheckResponse res = response.readEntity(AuthCheckResponse.class);
assertEquals(200, response.getStatus());
assertEquals(expected, res.matches());
assertEquals(expected, parseCheckResponse(response));
}
}
@@ -252,6 +258,35 @@ abstract class SecureValueRecoveryControllerBaseTest {
}
}
@Test
public void testAcceptsPasswordsOrTokens() {
final Response passwordsResponse = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"passwords": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (passwordsResponse) {
assertEquals(200, passwordsResponse.getStatus());
}
final Response tokensResponse = resourceExtension.getJerseyTest()
.target(pathPrefix + "/backup/auth/check")
.request()
.post(Entity.entity("""
{
"number": "+18005550123",
"tokens": ["aaa:bbb"]
}
""", MediaType.APPLICATION_JSON));
try (tokensResponse) {
assertEquals(200, tokensResponse.getStatus());
}
}
@Test
public void testHttpResponseCodeWhenNotAJson() throws Exception {
final Response response = resourceExtension.getJerseyTest()
@@ -264,11 +299,11 @@ abstract class SecureValueRecoveryControllerBaseTest {
}
}
private String token(final UUID uuid, final long timeMillis) {
String token(final UUID uuid, final long timeMillis) {
return token(credentials(uuid, timeMillis));
}
private static String token(final ExternalServiceCredentials credentials) {
static String token(final ExternalServiceCredentials credentials) {
return credentials.username() + ":" + credentials.password();
}
@@ -277,13 +312,14 @@ abstract class SecureValueRecoveryControllerBaseTest {
return credentialsGenerator.generateForUuid(uuid);
}
private static long day(final int n) {
static long day(final int n) {
return TimeUnit.DAYS.toMillis(n);
}
private static Account account(final UUID uuid) {
final Account a = new Account();
a.setUuid(uuid);
a.setSvr3ShareSet(TestRandomUtil.nextBytes(100));
return a;
}
}

View File

@@ -417,12 +417,15 @@ class AccountsTest {
}
@Test
void testReclaimAccountPreservesBcr() {
void testReclaimAccountPreservesFields() {
final String e164 = "+14151112222";
final UUID existingUuid = UUID.randomUUID();
final Account existingAccount =
generateAccount(e164, existingUuid, UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
// the backup credential request and share-set are always preserved across account reclaims
existingAccount.setBackupCredentialRequest(TestRandomUtil.nextBytes(32));
existingAccount.setSvr3ShareSet(TestRandomUtil.nextBytes(100));
createAccount(existingAccount);
final Account secondAccount =
generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
@@ -431,6 +434,7 @@ class AccountsTest {
final Account reclaimed = accounts.getByAccountIdentifier(existingUuid).get();
assertThat(reclaimed.getBackupCredentialRequest()).isEqualTo(existingAccount.getBackupCredentialRequest());
assertThat(reclaimed.getSvr3ShareSet()).isEqualTo(existingAccount.getSvr3ShareSet());
}
@Test