mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-21 23:28:04 +01:00
Add PUT /v2/account/number
This commit is contained in:
@@ -331,11 +331,12 @@ class AccountControllerTest {
|
||||
when(captchaChecker.verify(eq(VALID_CAPTCHA_TOKEN), anyString()))
|
||||
.thenReturn(new AssessmentResult(true, ""));
|
||||
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
|
||||
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2));
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST);
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoicePrefixLimiter)
|
||||
.validate(SENDER_OVER_PREFIX.substring(0, 4 + 2));
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST);
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
@@ -571,7 +572,7 @@ class AccountControllerTest {
|
||||
@Test
|
||||
void testSendCodeRateLimited() {
|
||||
when(registrationServiceClient.createRegistrationSession(any(), any()))
|
||||
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10))));
|
||||
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true)));
|
||||
|
||||
Response response =
|
||||
resources.getJerseyTest()
|
||||
@@ -2050,7 +2051,7 @@ class AccountControllerTest {
|
||||
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
|
||||
|
||||
MockUtils.updateRateLimiterResponseToFail(
|
||||
rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter);
|
||||
rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true);
|
||||
|
||||
final Response response = resources.getJerseyTest()
|
||||
.target(String.format("/v1/accounts/account/%s", accountIdentifier))
|
||||
@@ -2115,7 +2116,7 @@ class AccountControllerTest {
|
||||
void testLookupUsernameRateLimited() throws RateLimitExceededException {
|
||||
final Duration expectedRetryAfter = Duration.ofSeconds(13);
|
||||
MockUtils.updateRateLimiterResponseToFail(
|
||||
rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter);
|
||||
rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true);
|
||||
final Response response = resources.getJerseyTest()
|
||||
.target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1))
|
||||
.request()
|
||||
|
||||
@@ -0,0 +1,396 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
import javax.ws.rs.client.Entity;
|
||||
import javax.ws.rs.client.Invocation;
|
||||
import javax.ws.rs.core.HttpHeaders;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.Response;
|
||||
import org.apache.http.HttpStatus;
|
||||
import org.glassfish.jersey.server.ServerProperties;
|
||||
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.mockito.stubbing.Answer;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
|
||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.RegistrationSession;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class AccountControllerV2Test {
|
||||
|
||||
public static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format(
|
||||
PhoneNumberUtil.getInstance().getExampleNumber("US"),
|
||||
PhoneNumberUtil.PhoneNumberFormat.E164);
|
||||
|
||||
private final AccountsManager accountsManager = mock(AccountsManager.class);
|
||||
private final ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class);
|
||||
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
|
||||
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
|
||||
RegistrationRecoveryPasswordsManager.class);
|
||||
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
|
||||
RegistrationLockVerificationManager.class);
|
||||
private final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
|
||||
|
||||
private final ResourceExtension resources = ResourceExtension.builder()
|
||||
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
||||
.addProvider(AuthHelper.getAuthFilter())
|
||||
.addProvider(
|
||||
new PolymorphicAuthValueFactoryProvider.Binder<>(
|
||||
ImmutableSet.of(AuthenticatedAccount.class,
|
||||
DisabledPermittedAuthenticatedAccount.class)))
|
||||
.addProvider(new RateLimitExceededExceptionMapper())
|
||||
.addProvider(new ImpossiblePhoneNumberExceptionMapper())
|
||||
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
|
||||
.setMapper(SystemMapper.getMapper())
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addResource(
|
||||
new AccountControllerV2(accountsManager, changeNumberManager,
|
||||
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
|
||||
registrationLockVerificationManager, rateLimiters))
|
||||
.build();
|
||||
|
||||
@Nested
|
||||
class ChangeNumber {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
|
||||
|
||||
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer(
|
||||
(Answer<Account>) invocation -> {
|
||||
final Account account = invocation.getArgument(0, Account.class);
|
||||
final String number = invocation.getArgument(1, String.class);
|
||||
final String pniIdentityKey = invocation.getArgument(2, String.class);
|
||||
|
||||
final UUID uuid = account.getUuid();
|
||||
final List<Device> devices = account.getDevices();
|
||||
|
||||
final Account updatedAccount = mock(Account.class);
|
||||
when(updatedAccount.getUuid()).thenReturn(uuid);
|
||||
when(updatedAccount.getNumber()).thenReturn(number);
|
||||
when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey);
|
||||
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
|
||||
when(updatedAccount.getDevices()).thenReturn(devices);
|
||||
|
||||
for (long i = 1; i <= 3; i++) {
|
||||
final Optional<Device> d = account.getDevice(i);
|
||||
when(updatedAccount.getDevice(i)).thenReturn(d);
|
||||
}
|
||||
|
||||
return updatedAccount;
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void changeNumberSuccess() throws Exception {
|
||||
|
||||
when(registrationServiceClient.getSession(any(), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true))));
|
||||
|
||||
final AccountIdentityResponse accountIdentityResponse =
|
||||
resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||
.put(Entity.entity(
|
||||
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123",
|
||||
Collections.emptyList(),
|
||||
Collections.emptyMap(), Collections.emptyMap()),
|
||||
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
|
||||
|
||||
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
|
||||
any());
|
||||
|
||||
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
|
||||
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
|
||||
assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni());
|
||||
}
|
||||
|
||||
@Test
|
||||
void unprocessableRequestJson() {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
try (Response response = request.put(Entity.json(unprocessableJson()))) {
|
||||
assertEquals(400, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void missingBasicAuthorization() {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request();
|
||||
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||
assertEquals(401, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void invalidBasicAuthorization() {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, "Basic but-invalid");
|
||||
try (Response response = request.put(Entity.json(invalidRequestJson()))) {
|
||||
assertEquals(401, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void invalidRequestBody() {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
try (Response response = request.put(Entity.json(invalidRequestJson()))) {
|
||||
assertEquals(422, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void rateLimitedNumber() throws Exception {
|
||||
doThrow(new RateLimitExceededException(null, true))
|
||||
.when(registrationLimiter).validate(anyString());
|
||||
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||
assertEquals(429, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void registrationServiceTimeout() {
|
||||
when(registrationServiceClient.getSession(any(), any()))
|
||||
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
|
||||
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus,
|
||||
final String message) {
|
||||
when(registrationServiceClient.getSession(any(), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(session)));
|
||||
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||
assertEquals(expectedStatus, response.getStatus(), message);
|
||||
}
|
||||
}
|
||||
|
||||
static Stream<Arguments> registrationServiceSessionCheck() {
|
||||
return Stream.of(
|
||||
Arguments.of(null, 401, "session not found"),
|
||||
Arguments.of(new RegistrationSession("+18005551234", false), 400, "session number mismatch"),
|
||||
Arguments.of(new RegistrationSession(NEW_NUMBER, false), 401, "session not verified")
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(RegistrationLockError.class)
|
||||
void registrationLock(final RegistrationLockError error) throws Exception {
|
||||
when(registrationServiceClient.getSession(any(), any()))
|
||||
.thenReturn(
|
||||
CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true))));
|
||||
|
||||
when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class)));
|
||||
|
||||
final Exception e = switch (error) {
|
||||
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
|
||||
case RATE_LIMITED -> new RateLimitExceededException(null, true);
|
||||
};
|
||||
doThrow(e)
|
||||
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any());
|
||||
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||
assertEquals(error.getExpectedStatus(), response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void recoveryPasswordManagerVerificationTrue() throws Exception {
|
||||
when(registrationRecoveryPasswordsManager.verify(any(), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(true));
|
||||
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
final byte[] recoveryPassword = new byte[32];
|
||||
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) {
|
||||
assertEquals(200, response.getStatus());
|
||||
|
||||
final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class);
|
||||
|
||||
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
|
||||
any());
|
||||
|
||||
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
|
||||
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
|
||||
assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void recoveryPasswordManagerVerificationFalse() {
|
||||
when(registrationRecoveryPasswordsManager.verify(any(), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(false));
|
||||
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v2/accounts/number")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION,
|
||||
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(new byte[32], NEW_NUMBER)))) {
|
||||
assertEquals(403, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Valid request JSON with the given Recovery Password
|
||||
*/
|
||||
private static String requestJsonRecoveryPassword(final byte[] recoveryPassword, final String newNumber) {
|
||||
return requestJson("", recoveryPassword, newNumber);
|
||||
}
|
||||
|
||||
/**
|
||||
* Valid request JSON with the give session ID and recovery password
|
||||
*/
|
||||
private static String requestJson(final String sessionId, final byte[] recoveryPassword, final String newNumber) {
|
||||
return String.format("""
|
||||
{
|
||||
"sessionId": "%s",
|
||||
"recoveryPassword": "%s",
|
||||
"number": "%s",
|
||||
"reglock": "1234",
|
||||
"pniIdentityKey": "5678",
|
||||
"deviceMessages": [],
|
||||
"devicePniSignedPrekeys": {},
|
||||
"pniRegistrationIds": {}
|
||||
}
|
||||
""", encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber);
|
||||
}
|
||||
|
||||
/**
|
||||
* Valid request JSON with the give session ID
|
||||
*/
|
||||
private static String requestJson(final String sessionId, final String newNumber) {
|
||||
return requestJson(sessionId, new byte[0], newNumber);
|
||||
}
|
||||
|
||||
/**
|
||||
* Request JSON in the shape of {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest}, but that
|
||||
* fails validation
|
||||
*/
|
||||
private static String invalidRequestJson() {
|
||||
return """
|
||||
{
|
||||
"sessionId": null
|
||||
}
|
||||
""";
|
||||
}
|
||||
|
||||
/**
|
||||
* Request JSON that cannot be marshalled into
|
||||
* {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest}
|
||||
*/
|
||||
private static String unprocessableJson() {
|
||||
return """
|
||||
{
|
||||
"sessionId": []
|
||||
}
|
||||
""";
|
||||
}
|
||||
|
||||
private static String encodeSessionId(final String sessionId) {
|
||||
return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
|
||||
return Base64.getEncoder().encodeToString(recoveryPassword);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -86,7 +86,8 @@ class ChallengeControllerTest {
|
||||
""";
|
||||
|
||||
final Duration retryAfter = Duration.ofMinutes(17);
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerPushChallenge(any(), any());
|
||||
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager)
|
||||
.answerPushChallenge(any(), any());
|
||||
|
||||
final Response response = EXTENSION.target("/v1/challenge")
|
||||
.request()
|
||||
@@ -128,7 +129,8 @@ class ChallengeControllerTest {
|
||||
""";
|
||||
|
||||
final Duration retryAfter = Duration.ofMinutes(17);
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerRecaptchaChallenge(any(), any(), any(), any());
|
||||
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager)
|
||||
.answerRecaptchaChallenge(any(), any(), any(), any());
|
||||
|
||||
final Response response = EXTENSION.target("/v1/challenge")
|
||||
.request()
|
||||
|
||||
@@ -255,7 +255,8 @@ class ProfileControllerTest {
|
||||
|
||||
@Test
|
||||
void testProfileGetByAciRateLimited() throws RateLimitExceededException {
|
||||
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID);
|
||||
doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter)
|
||||
.validate(AuthHelper.VALID_UUID);
|
||||
|
||||
Response response= resources.getJerseyTest()
|
||||
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
|
||||
@@ -326,7 +327,8 @@ class ProfileControllerTest {
|
||||
|
||||
@Test
|
||||
void testProfileGetByPniRateLimited() throws RateLimitExceededException {
|
||||
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID);
|
||||
doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter)
|
||||
.validate(AuthHelper.VALID_UUID);
|
||||
|
||||
Response response= resources.getJerseyTest()
|
||||
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
|
||||
|
||||
@@ -1,9 +1,26 @@
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.Base64;
|
||||
import java.util.UUID;
|
||||
import javax.ws.rs.client.Entity;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.Response;
|
||||
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -20,25 +37,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||
|
||||
import javax.ws.rs.client.Entity;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.Response;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.Base64;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class ProvisioningControllerTest {
|
||||
|
||||
@@ -101,7 +99,7 @@ class ProvisioningControllerTest {
|
||||
final String destination = UUID.randomUUID().toString();
|
||||
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO))
|
||||
doThrow(new RateLimitExceededException(Duration.ZERO, true))
|
||||
.when(messagesRateLimiter).validate(AuthHelper.VALID_UUID);
|
||||
|
||||
try (final Response response = RESOURCE_EXTENSION.getJerseyTest()
|
||||
|
||||
@@ -13,6 +13,7 @@ import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
@@ -37,6 +38,7 @@ import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
|
||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
|
||||
@@ -51,12 +53,18 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class RegistrationControllerTest {
|
||||
|
||||
private static final String NUMBER = "+18005551212";
|
||||
private static final String NUMBER = PhoneNumberUtil.getInstance().format(
|
||||
PhoneNumberUtil.getInstance().getExampleNumber("US"),
|
||||
PhoneNumberUtil.PhoneNumberFormat.E164);
|
||||
|
||||
public static final String PASSWORD = "password";
|
||||
|
||||
private final AccountsManager accountsManager = mock(AccountsManager.class);
|
||||
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
|
||||
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
|
||||
@@ -66,7 +74,6 @@ class RegistrationControllerTest {
|
||||
private final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||
|
||||
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
|
||||
private final RateLimiter pinLimiter = mock(RateLimiter.class);
|
||||
|
||||
private final ResourceExtension resources = ResourceExtension.builder()
|
||||
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
||||
@@ -76,14 +83,14 @@ class RegistrationControllerTest {
|
||||
.setMapper(SystemMapper.getMapper())
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addResource(
|
||||
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager,
|
||||
registrationRecoveryPasswordsManager, rateLimiters))
|
||||
new RegistrationController(accountsManager,
|
||||
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
|
||||
registrationLockVerificationManager, rateLimiters))
|
||||
.build();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
|
||||
when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -130,25 +137,23 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(invalidRequestJson()))) {
|
||||
assertEquals(422, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void rateLimitedSession() throws Exception {
|
||||
final String sessionId = "sessionId";
|
||||
void rateLimitedNumber() throws Exception {
|
||||
doThrow(RateLimitExceededException.class)
|
||||
.when(registrationLimiter).validate(encodeSessionId(sessionId));
|
||||
.when(registrationLimiter).validate(NUMBER);
|
||||
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
try (Response response = request.post(Entity.json(requestJson(sessionId)))) {
|
||||
assertEquals(413, response.getStatus());
|
||||
// In the future, change to assertEquals(429, response.getStatus());
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||
assertEquals(429, response.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,7 +165,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
||||
}
|
||||
@@ -174,7 +179,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
|
||||
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
||||
}
|
||||
@@ -190,7 +195,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||
assertEquals(expectedStatus, response.getStatus(), message);
|
||||
}
|
||||
@@ -214,7 +219,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
final byte[] recoveryPassword = new byte[32];
|
||||
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
|
||||
assertEquals(200, response.getStatus());
|
||||
@@ -229,7 +234,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
|
||||
assertEquals(403, response.getStatus());
|
||||
}
|
||||
@@ -245,7 +250,7 @@ class RegistrationControllerTest {
|
||||
|
||||
final Exception e = switch (error) {
|
||||
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
|
||||
case RATE_LIMITED -> new RateLimitExceededException(null);
|
||||
case RATE_LIMITED -> new RateLimitExceededException(null, true);
|
||||
};
|
||||
doThrow(e)
|
||||
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any());
|
||||
@@ -253,7 +258,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||
assertEquals(error.getExpectedStatus(), response.getStatus());
|
||||
}
|
||||
@@ -286,7 +291,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) {
|
||||
assertEquals(expectedStatus, response.getStatus());
|
||||
}
|
||||
@@ -294,7 +299,7 @@ class RegistrationControllerTest {
|
||||
|
||||
// this is functionally the same as deviceTransferAvailable(existingAccount=false)
|
||||
@Test
|
||||
void success() throws Exception {
|
||||
void registrationSuccess() throws Exception {
|
||||
when(registrationServiceClient.getSession(any(), any()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true))));
|
||||
when(accountsManager.create(any(), any(), any(), any(), any()))
|
||||
@@ -303,7 +308,7 @@ class RegistrationControllerTest {
|
||||
final Invocation.Builder request = resources.getJerseyTest()
|
||||
.target("/v1/registration")
|
||||
.request()
|
||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
||||
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||
assertEquals(200, response.getStatus());
|
||||
}
|
||||
@@ -365,13 +370,8 @@ class RegistrationControllerTest {
|
||||
""";
|
||||
}
|
||||
|
||||
private static String authorizationHeader(final String number) {
|
||||
return "Basic " + Base64.getEncoder().encodeToString(
|
||||
String.format("%s:password", number).getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
private static String encodeSessionId(final String sessionId) {
|
||||
return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
|
||||
return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
|
||||
|
||||
@@ -72,11 +72,11 @@ public class RateLimitedByIpTest {
|
||||
public void testRateLimits() throws Exception {
|
||||
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
|
||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
||||
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
|
||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ public class RateLimitedByIpTest {
|
||||
validateSuccess("/test/loose", "");
|
||||
|
||||
// also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP
|
||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.anyString());
|
||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.anyString());
|
||||
validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
||||
validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR);
|
||||
validateSuccess("/test/loose", "");
|
||||
|
||||
@@ -340,7 +340,7 @@ class KeysControllerTest {
|
||||
@Test
|
||||
void testGetKeysRateLimited() throws RateLimitExceededException {
|
||||
Duration retryAfter = Duration.ofSeconds(31);
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(anyString());
|
||||
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString());
|
||||
|
||||
Response result = resources.getJerseyTest()
|
||||
.target(String.format("/v2/keys/%s/*", EXISTS_PNI))
|
||||
|
||||
@@ -60,11 +60,12 @@ public final class MockUtils {
|
||||
final RateLimiters rateLimitersMock,
|
||||
final RateLimiters.Handle handle,
|
||||
final String input,
|
||||
final Duration retryAfter) {
|
||||
final Duration retryAfter,
|
||||
final boolean legacyStatusCode) {
|
||||
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
||||
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle));
|
||||
try {
|
||||
doThrow(new RateLimitExceededException(retryAfter)).when(mockRateLimiter).validate(eq(input));
|
||||
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
|
||||
} catch (final RateLimitExceededException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user