Send "account already exists" flag when creating registration sessions

This commit is contained in:
Jon Chambers
2023-03-14 17:25:49 -04:00
committed by Jon Chambers
parent 2052e62c01
commit 35606a9afd
7 changed files with 96 additions and 39 deletions

View File

@@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
@@ -386,7 +387,7 @@ class AccountControllerTest {
@Test
void testGetFcmPreauth() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@@ -401,7 +402,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any());
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture());
@@ -411,7 +412,7 @@ class AccountControllerTest {
@Test
void testGetFcmPreauthIvoryCoast() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16]));
Response response = resources.getJerseyTest()
@@ -424,7 +425,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse("+2250707312345", null)), any());
eq(PhoneNumberUtil.getInstance().parse("+2250707312345", null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture());
@@ -434,7 +435,7 @@ class AccountControllerTest {
@Test
void testGetApnPreauth() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@@ -449,7 +450,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any());
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@@ -459,7 +460,7 @@ class AccountControllerTest {
@Test
void testGetApnPreauthExplicitVoip() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@@ -475,7 +476,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any());
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@@ -485,7 +486,7 @@ class AccountControllerTest {
@Test
void testGetApnPreauthExplicitNoVoip() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@@ -501,7 +502,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any());
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN), challengeTokenCaptor.capture());
@@ -546,7 +547,7 @@ class AccountControllerTest {
void testGetPreauthExistingSession() throws NumberParseException {
final String existingPushCode = "existing-push-code";
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(
@@ -561,7 +562,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient, never()).createRegistrationSession(any(), any());
verify(registrationServiceClient, never()).createRegistrationSession(any(), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@@ -571,7 +572,7 @@ class AccountControllerTest {
@Test
void testGetPreauthExistingSessionWithoutPushCode() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(
@@ -586,7 +587,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient, never()).createRegistrationSession(any(), any());
verify(registrationServiceClient, never()).createRegistrationSession(any(), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@@ -624,7 +625,7 @@ class AccountControllerTest {
void testSendCode() {
final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8);
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
@@ -648,7 +649,7 @@ class AccountControllerTest {
@Test
void testSendCodeRateLimited() {
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true)));
Response response =
@@ -709,7 +710,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -732,7 +733,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -785,7 +786,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -809,7 +810,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -834,7 +835,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -944,7 +945,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -975,7 +976,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -998,7 +999,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =
@@ -2092,7 +2093,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any()))
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId));
Response response =

View File

@@ -11,7 +11,9 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@@ -19,6 +21,8 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import com.google.i18n.phonenumbers.NumberParseException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException;
@@ -46,6 +50,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.captcha.AssessmentResult;
import org.whispersystems.textsecuregcm.captcha.RegistrationCaptchaManager;
@@ -62,6 +67,8 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceException;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceSenderException;
import org.whispersystems.textsecuregcm.registration.VerificationSession;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@@ -81,6 +88,7 @@ class VerificationControllerTest {
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final Clock clock = Clock.systemUTC();
private final RateLimiter captchaLimiter = mock(RateLimiter.class);
@@ -96,7 +104,7 @@ class VerificationControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager,
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, clock))
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, clock))
.build();
@BeforeEach
@@ -105,6 +113,8 @@ class VerificationControllerTest {
.thenReturn(captchaLimiter);
when(rateLimiters.getVerificationPushChallengeLimiter())
.thenReturn(pushChallengeLimiter);
when(accountsManager.getByE164(any())).thenReturn(Optional.empty());
}
@ParameterizedTest
@@ -153,7 +163,7 @@ class VerificationControllerTest {
@Test
void createSessionRateLimited() {
when(registrationServiceClient.createRegistrationSessionSession(any(), any()))
when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, true)));
final Invocation.Builder request = resources.getJerseyTest()
@@ -167,7 +177,7 @@ class VerificationControllerTest {
@Test
void createSessionRegistrationServiceError() {
when(registrationServiceClient.createRegistrationSessionSession(any(), any()))
when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("expected service error")));
final Invocation.Builder request = resources.getJerseyTest()
@@ -183,7 +193,7 @@ class VerificationControllerTest {
@MethodSource
void createSessionSuccess(final String pushToken, final String pushTokenType,
final List<VerificationSession.Information> expectedRequestedInformation) {
when(registrationServiceClient.createRegistrationSessionSession(any(), any()))
when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn(
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@@ -214,6 +224,37 @@ class VerificationControllerTest {
);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void createSessionReregistration(final boolean isReregistration) throws NumberParseException {
when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn(
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
SESSION_EXPIRATION_SECONDS)));
when(verificationSessionManager.insert(any(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.getByE164(NUMBER))
.thenReturn(isReregistration ? Optional.of(mock(Account.class)) : Optional.empty());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session")
.request()
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1");
try (final Response response = request.post(Entity.json(createSessionJson(NUMBER, null, null)))) {
assertEquals(HttpStatus.SC_OK, response.getStatus());
verify(registrationServiceClient).createRegistrationSessionSession(
eq(PhoneNumberUtil.getInstance().parse(NUMBER, null)),
eq(isReregistration),
any()
);
}
}
@Test
void patchSessionMalformedId() {
final String invalidSessionId = "()()()";