Add client challenges for prekey and message rate limiters

This commit is contained in:
Jon Chambers
2021-05-11 17:21:32 -04:00
committed by GitHub
parent 5752853bba
commit 46110d4d65
46 changed files with 2289 additions and 255 deletions

View File

@@ -347,7 +347,6 @@ class DynamicConfigurationTest {
assertThat(emptyConfig.getLimits().getUnsealedSenderNumber().getMaxCardinality()).isEqualTo(100);
assertThat(emptyConfig.getLimits().getUnsealedSenderNumber().getTtl()).isEqualTo(Duration.ofDays(1));
assertThat(emptyConfig.getLimits().getUnsealedSenderNumber().getTtlJitter()).isEqualTo(Duration.ofDays(1));
}
{
@@ -355,15 +354,46 @@ class DynamicConfigurationTest {
"limits:\n"
+ " unsealedSenderNumber:\n"
+ " maxCardinality: 99\n"
+ " ttl: PT23H\n"
+ " ttlJitter: PT22H";
+ " ttl: PT23H";
final CardinalityRateLimitConfiguration unsealedSenderNumber = DynamicConfigurationManager.OBJECT_MAPPER
.readValue(limitsConfig, DynamicConfiguration.class)
.getLimits().getUnsealedSenderNumber();
assertThat(unsealedSenderNumber.getMaxCardinality()).isEqualTo(99);
assertThat(unsealedSenderNumber.getTtl()).isEqualTo(Duration.ofHours(23));
assertThat(unsealedSenderNumber.getTtlJitter()).isEqualTo(Duration.ofHours(22));
}
}
@Test
void testParseRateLimitReset() throws JsonProcessingException {
{
final String emptyConfigYaml = "test: true";
final DynamicConfiguration emptyConfig = DynamicConfigurationManager.OBJECT_MAPPER.readValue(
emptyConfigYaml, DynamicConfiguration.class);
assertThat(emptyConfig.getRateLimitChallengeConfiguration().getClientSupportedVersions()).isEmpty();
assertThat(emptyConfig.getRateLimitChallengeConfiguration().isPreKeyLimitEnforced()).isFalse();
assertThat(emptyConfig.getRateLimitChallengeConfiguration().isUnsealedSenderLimitEnforced()).isFalse();
}
{
final String rateLimitChallengeConfig =
"rateLimitChallenge:\n"
+ " preKeyLimitEnforced: true\n"
+ " clientSupportedVersions:\n"
+ " IOS: 5.1.0\n"
+ " ANDROID: 5.2.0\n"
+ " DESKTOP: 5.0.0";
DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration = DynamicConfigurationManager.OBJECT_MAPPER
.readValue(rateLimitChallengeConfig, DynamicConfiguration.class)
.getRateLimitChallengeConfiguration();
final Map<ClientPlatform, Semver> clientSupportedVersions = rateLimitChallengeConfiguration.getClientSupportedVersions();
assertThat(clientSupportedVersions.get(ClientPlatform.IOS)).isEqualTo(new Semver("5.1.0"));
assertThat(clientSupportedVersions.get(ClientPlatform.ANDROID)).isEqualTo(new Semver("5.2.0"));
assertThat(clientSupportedVersions.get(ClientPlatform.DESKTOP)).isEqualTo(new Semver("5.0.0"));
assertThat(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).isTrue();
assertThat(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).isFalse();
}
}
}

View File

@@ -0,0 +1,200 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration;
import java.util.Set;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
class ChallengeControllerTest {
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(Set.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RetryLaterExceptionMapper())
.addResource(challengeController)
.build();
@AfterEach
void teardown() {
reset(rateLimitChallengeManager);
}
@Test
void testHandlePushChallenge() throws RateLimitExceededException {
final String pushChallengeJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\",\n"
+ " \"challenge\": \"Hello I am a push challenge token\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(pushChallengeJson));
assertEquals(200, response.getStatus());
verify(rateLimitChallengeManager).answerPushChallenge(AuthHelper.VALID_ACCOUNT, "Hello I am a push challenge token");
}
@Test
void testHandlePushChallengeRateLimited() throws RateLimitExceededException {
final String pushChallengeJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\",\n"
+ " \"challenge\": \"Hello I am a push challenge token\"\n"
+ "}";
final Duration retryAfter = Duration.ofMinutes(17);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerPushChallenge(any(), any());
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(pushChallengeJson));
assertEquals(413, response.getStatus());
assertEquals(String.valueOf(retryAfter.toSeconds()), response.getHeaderString("Retry-After"));
}
@Test
void testHandleRecaptcha() throws RateLimitExceededException {
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("X-Forwarded-For", "10.0.0.1")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(recaptchaChallengeJson));
assertEquals(200, response.getStatus());
verify(rateLimitChallengeManager).answerRecaptchaChallenge(AuthHelper.VALID_ACCOUNT, "The value of the solved captcha token", "10.0.0.1");
}
@Test
void testHandleRecaptchaRateLimited() throws RateLimitExceededException {
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final Duration retryAfter = Duration.ofMinutes(17);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerRecaptchaChallenge(any(), any(), any());
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("X-Forwarded-For", "10.0.0.1")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(recaptchaChallengeJson));
assertEquals(413, response.getStatus());
assertEquals(String.valueOf(retryAfter.toSeconds()), response.getHeaderString("Retry-After"));
}
@Test
void testHandleRecaptchaNoForwardedFor() {
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(recaptchaChallengeJson));
assertEquals(400, response.getStatus());
verifyZeroInteractions(rateLimitChallengeManager);
}
@Test
void testHandleUnrecognizedAnswer() {
final String unrecognizedJson = "{\n"
+ " \"type\": \"unrecognized\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("X-Forwarded-For", "10.0.0.1")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(unrecognizedJson));
assertEquals(400, response.getStatus());
verifyZeroInteractions(rateLimitChallengeManager);
}
@Test
void testRequestPushChallenge() throws NotPushRegisteredException {
{
final Response response = EXTENSION.target("/v1/challenge/push")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.post(Entity.text(""));
assertEquals(200, response.getStatus());
}
{
doThrow(NotPushRegisteredException.class).when(rateLimitChallengeManager).sendPushChallenge(AuthHelper.VALID_ACCOUNT_TWO);
final Response response = EXTENSION.target("/v1/challenge/push")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER_TWO, AuthHelper.VALID_PASSWORD_TWO))
.post(Entity.text(""));
assertEquals(404, response.getStatus());
}
}
@Test
void testValidationError() {
final String unrecognizedJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(unrecognizedJson));
assertEquals(422, response.getStatus());
}
}

View File

@@ -5,9 +5,16 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import java.util.concurrent.ScheduledExecutorService;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.UnsealedSenderRateLimiter;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -16,12 +23,6 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import java.util.concurrent.ScheduledExecutorService;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
public class MessageControllerMetricsTest extends AbstractRedisClusterTest {
private MessageController messageController;
@@ -35,8 +36,10 @@ public class MessageControllerMetricsTest extends AbstractRedisClusterTest {
mock(ReceiptSender.class),
mock(AccountsManager.class),
mock(MessagesManager.class),
mock(UnsealedSenderRateLimiter.class),
mock(ApnFallbackManager.class),
mock(DynamicConfigurationManager.class),
mock(RateLimitChallengeManager.class),
getRedisCluster(),
mock(ScheduledExecutorService.class));
}

View File

@@ -0,0 +1,63 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import static org.junit.jupiter.api.Assertions.*;
class AnswerChallengeRequestTest {
@Test
void parse() throws JsonProcessingException {
{
final String pushChallengeJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\",\n"
+ " \"challenge\": \"Hello I am a push challenge token\"\n"
+ "}";
final AnswerChallengeRequest answerChallengeRequest =
SystemMapper.getMapper().readValue(pushChallengeJson, AnswerChallengeRequest.class);
assertTrue(answerChallengeRequest instanceof AnswerPushChallengeRequest);
assertEquals("Hello I am a push challenge token",
((AnswerPushChallengeRequest) answerChallengeRequest).getChallenge());
}
{
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final AnswerChallengeRequest answerChallengeRequest =
SystemMapper.getMapper().readValue(recaptchaChallengeJson, AnswerChallengeRequest.class);
assertTrue(answerChallengeRequest instanceof AnswerRecaptchaChallengeRequest);
assertEquals("A server-generated token",
((AnswerRecaptchaChallengeRequest) answerChallengeRequest).getToken());
assertEquals("The value of the solved captcha token",
((AnswerRecaptchaChallengeRequest) answerChallengeRequest).getCaptcha());
}
{
final String unrecognizedTypeJson = "{\n"
+ " \"type\": \"unrecognized\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
assertThrows(InvalidTypeIdException.class,
() -> SystemMapper.getMapper().readValue(unrecognizedTypeJson, AnswerChallengeRequest.class));
}
}
}

View File

@@ -17,43 +17,45 @@ import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
public class CardinalityRateLimiterTest extends AbstractRedisClusterTest {
@Before
public void setUp() throws Exception {
super.setUp();
}
@Before
public void setUp() throws Exception {
super.setUp();
}
@After
public void tearDown() throws Exception {
super.tearDown();
}
@After
public void tearDown() throws Exception {
super.tearDown();
}
@Test
public void testValidate() {
final int maxCardinality = 10;
final CardinalityRateLimiter rateLimiter = new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), Duration.ofDays(1), maxCardinality);
@Test
public void testValidate() {
final int maxCardinality = 10;
final CardinalityRateLimiter rateLimiter =
new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), maxCardinality);
final String source = "+18005551234";
int validatedAttempts = 0;
int blockedAttempts = 0;
for (int i = 0; i < maxCardinality * 2; i++) {
try {
rateLimiter.validate(source, String.valueOf(i));
validatedAttempts++;
} catch (final RateLimitExceededException e) {
blockedAttempts++;
}
}
assertTrue(validatedAttempts >= maxCardinality);
assertTrue(blockedAttempts > 0);
final String secondSource = "+18005554321";
final String source = "+18005551234";
int validatedAttempts = 0;
int blockedAttempts = 0;
for (int i = 0; i < maxCardinality * 2; i++) {
try {
rateLimiter.validate(secondSource, "test");
rateLimiter.validate(source, String.valueOf(i), rateLimiter.getDefaultMaxCardinality());
validatedAttempts++;
} catch (final RateLimitExceededException e) {
fail("New source should not trigger a rate limit exception on first attempted validation");
blockedAttempts++;
}
}
assertTrue(validatedAttempts >= maxCardinality);
assertTrue(blockedAttempts > 0);
final String secondSource = "+18005554321";
try {
rateLimiter.validate(secondSource, "test", rateLimiter.getDefaultMaxCardinality());
} catch (final RateLimitExceededException e) {
fail("New source should not trigger a rate limit exception on first attempted validation");
}
}
}

View File

@@ -0,0 +1,66 @@
package org.whispersystems.textsecuregcm.limits;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
class PreKeyRateLimiterTest {
private Account account;
private PreKeyRateLimiter preKeyRateLimiter;
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
private RateLimiter dailyPreKeyLimiter;
@BeforeEach
void setup() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
dailyPreKeyLimiter = mock(RateLimiter.class);
when(rateLimiters.getDailyPreKeysLimiter()).thenReturn(dailyPreKeyLimiter);
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
rateLimitChallengeConfiguration = mock(DynamicRateLimitChallengeConfiguration.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getRateLimitChallengeConfiguration()).thenReturn(rateLimitChallengeConfiguration);
preKeyRateLimiter = new PreKeyRateLimiter(rateLimiters, dynamicConfigurationManager, mock(RateLimitResetMetricsManager.class));
account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551111");
when(account.getUuid()).thenReturn(UUID.randomUUID());
}
@Test
void enforcementConfiguration() throws RateLimitExceededException {
doThrow(RateLimitExceededException.class)
.when(dailyPreKeyLimiter).validate(any());
when(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).thenReturn(false);
preKeyRateLimiter.validate(account);
when(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).thenReturn(true);
assertThrows(RateLimitExceededException.class, () -> preKeyRateLimiter.validate(account));
when(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).thenReturn(false);
preKeyRateLimiter.validate(account);
}
}

View File

@@ -0,0 +1,190 @@
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import com.vdurmont.semver4j.Semver;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
class RateLimitChallengeManagerTest {
private PushChallengeManager pushChallengeManager;
private RecaptchaClient recaptchaClient;
private PreKeyRateLimiter preKeyRateLimiter;
private UnsealedSenderRateLimiter unsealedSenderRateLimiter;
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
private RateLimiters rateLimiters;
private RateLimitChallengeManager rateLimitChallengeManager;
@BeforeEach
void setUp() {
pushChallengeManager = mock(PushChallengeManager.class);
recaptchaClient = mock(RecaptchaClient.class);
preKeyRateLimiter = mock(PreKeyRateLimiter.class);
unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class);
rateLimiters = mock(RateLimiters.class);
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
rateLimitChallengeConfiguration = mock(DynamicRateLimitChallengeConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getRateLimitChallengeConfiguration()).thenReturn(rateLimitChallengeConfiguration);
rateLimitChallengeManager = new RateLimitChallengeManager(
pushChallengeManager,
recaptchaClient,
preKeyRateLimiter,
unsealedSenderRateLimiter,
rateLimiters,
dynamicConfigurationManager);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void answerPushChallenge(final boolean successfulChallenge) throws RateLimitExceededException {
final Account account = mock(Account.class);
when(pushChallengeManager.answerChallenge(eq(account), any())).thenReturn(successfulChallenge);
when(rateLimiters.getPushChallengeAttemptLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getPushChallengeSuccessLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getRateLimitResetLimiter()).thenReturn(mock(RateLimiter.class));
rateLimitChallengeManager.answerPushChallenge(account, "challenge");
if (successfulChallenge) {
verify(preKeyRateLimiter).handleRateLimitReset(account);
verify(unsealedSenderRateLimiter).handleRateLimitReset(account);
} else {
verifyZeroInteractions(preKeyRateLimiter);
verifyZeroInteractions(unsealedSenderRateLimiter);
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void answerRecaptchaChallenge(final boolean successfulChallenge) throws RateLimitExceededException {
final Account account = mock(Account.class);
when(recaptchaClient.verify(any(), any())).thenReturn(successfulChallenge);
when(rateLimiters.getRecaptchaChallengeAttemptLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getRecaptchaChallengeSuccessLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getRateLimitResetLimiter()).thenReturn(mock(RateLimiter.class));
rateLimitChallengeManager.answerRecaptchaChallenge(account, "captcha", "10.0.0.1");
if (successfulChallenge) {
verify(preKeyRateLimiter).handleRateLimitReset(account);
verify(unsealedSenderRateLimiter).handleRateLimitReset(account);
} else {
verifyZeroInteractions(preKeyRateLimiter);
verifyZeroInteractions(unsealedSenderRateLimiter);
}
}
@ParameterizedTest
@MethodSource
void shouldIssueRateLimitChallenge(final String userAgent, final boolean expectIssueChallenge) {
when(rateLimitChallengeConfiguration.getMinimumSupportedVersion(any())).thenReturn(Optional.empty());
when(rateLimitChallengeConfiguration.getMinimumSupportedVersion(ClientPlatform.ANDROID))
.thenReturn(Optional.of(new Semver("5.6.0")));
when(rateLimitChallengeConfiguration.getMinimumSupportedVersion(ClientPlatform.DESKTOP))
.thenReturn(Optional.of(new Semver("5.0.0-beta.2")));
assertEquals(expectIssueChallenge, rateLimitChallengeManager.shouldIssueRateLimitChallenge(userAgent));
}
private static Stream<Arguments> shouldIssueRateLimitChallenge() {
return Stream.of(
Arguments.of("Signal-Android/5.1.2 Android/30", false),
Arguments.of("Signal-Android/5.6.0 Android/30", true),
Arguments.of("Signal-Android/5.11.1 Android/30", true),
Arguments.of("Signal-Desktop/5.0.0-beta.3 macOS/11", true),
Arguments.of("Signal-Desktop/5.0.0-beta.1 Windows/3.1", false),
Arguments.of("Signal-Desktop/5.2.0 Debian/11", true),
Arguments.of("Signal-iOS/5.1.2 iOS/12.2", false),
Arguments.of("anything-else", false)
);
}
@ParameterizedTest
@MethodSource
void getChallengeOptions(final boolean captchaAttemptPermitted,
final boolean captchaSuccessPermitted,
final boolean pushAttemptPermitted,
final boolean pushSuccessPermitted,
final boolean expectCaptcha,
final boolean expectPushChallenge) {
final RateLimiter recaptchaChallengeAttemptLimiter = mock(RateLimiter.class);
final RateLimiter recaptchaChallengeSuccessLimiter = mock(RateLimiter.class);
final RateLimiter pushChallengeAttemptLimiter = mock(RateLimiter.class);
final RateLimiter pushChallengeSuccessLimiter = mock(RateLimiter.class);
when(rateLimiters.getRecaptchaChallengeAttemptLimiter()).thenReturn(recaptchaChallengeAttemptLimiter);
when(rateLimiters.getRecaptchaChallengeSuccessLimiter()).thenReturn(recaptchaChallengeSuccessLimiter);
when(rateLimiters.getPushChallengeAttemptLimiter()).thenReturn(pushChallengeAttemptLimiter);
when(rateLimiters.getPushChallengeSuccessLimiter()).thenReturn(pushChallengeSuccessLimiter);
when(recaptchaChallengeAttemptLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(captchaAttemptPermitted);
when(recaptchaChallengeSuccessLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(captchaSuccessPermitted);
when(pushChallengeAttemptLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(pushAttemptPermitted);
when(pushChallengeSuccessLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(pushSuccessPermitted);
final int expectedLength = (expectCaptcha ? 1 : 0) + (expectPushChallenge ? 1 : 0);
final List<String> options = rateLimitChallengeManager.getChallengeOptions(mock(Account.class));
assertEquals(expectedLength, options.size());
if (expectCaptcha) {
assertTrue(options.contains(RateLimitChallengeManager.OPTION_RECAPTCHA));
}
if (expectPushChallenge) {
assertTrue(options.contains(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE));
}
}
private static Stream<Arguments> getChallengeOptions() {
return Stream.of(
Arguments.of(false, false, false, false, false, false),
Arguments.of(false, false, false, true, false, false),
Arguments.of(false, false, true, false, false, false),
Arguments.of(false, false, true, true, false, true),
Arguments.of(false, true, false, false, false, false),
Arguments.of(false, true, false, true, false, false),
Arguments.of(false, true, true, false, false, false),
Arguments.of(false, true, true, true, false, true),
Arguments.of(true, false, false, false, false, false),
Arguments.of(true, false, false, true, false, false),
Arguments.of(true, false, true, false, false, false),
Arguments.of(true, false, true, true, false, true),
Arguments.of(true, true, false, false, true, false),
Arguments.of(true, true, false, true, true, false),
Arguments.of(true, true, true, false, true, false),
Arguments.of(true, true, true, true, true, true)
);
}
}

View File

@@ -0,0 +1,58 @@
package org.whispersystems.textsecuregcm.limits;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.util.Duration;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import java.util.UUID;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.storage.Account;
public class RateLimitResetMetricsManagerTest extends AbstractRedisClusterTest {
private RateLimitResetMetricsManager metricsManager;
private SimpleMeterRegistry meterRegistry;
@Before
@Override
public void setUp() throws Exception {
super.setUp();
meterRegistry = new SimpleMeterRegistry();
metricsManager = new RateLimitResetMetricsManager(getRedisCluster(), meterRegistry);
}
@Test
public void testRecordMetrics() {
final Account firstAccount = mock(Account.class);
when(firstAccount.getUuid()).thenReturn(UUID.randomUUID());
final Account secondAccount = mock(Account.class);
when(secondAccount.getUuid()).thenReturn(UUID.randomUUID());
metricsManager.recordMetrics(firstAccount, true, "counter", "enforced", "total", Duration.hours(1).toSeconds());
metricsManager.recordMetrics(firstAccount, true, "counter", "enforced", "total", Duration.hours(1).toSeconds());
metricsManager.recordMetrics(secondAccount, false, "counter", "unenforced", "total", Duration.hours(1).toSeconds());
final double counterTotal = meterRegistry.get("counter").counters().stream()
.map(Counter::count)
.reduce(Double::sum)
.orElseThrow();
assertEquals(3, counterTotal, 0.0);
final long enforcedCount = getRedisCluster().withCluster(conn -> conn.sync().pfcount("enforced"));
assertEquals(1L, enforcedCount);
final long unenforcedCount = getRedisCluster().withCluster(conn -> conn.sync().pfcount("unenforced"));
assertEquals(1L, unenforcedCount);
final long total = getRedisCluster().withCluster(conn -> conn.sync().pfcount("total"));
assertEquals(2L, total);
}
}

View File

@@ -0,0 +1,118 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.time.Duration;
import java.util.UUID;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest {
private Account sender;
private Account firstDestination;
private Account secondDestination;
private UnsealedSenderRateLimiter unsealedSenderRateLimiter;
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
@Before
@Override
public void setUp() throws Exception {
super.setUp();
final RateLimiters rateLimiters = mock(RateLimiters.class);
final CardinalityRateLimiter cardinalityRateLimiter =
new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), 1);
when(rateLimiters.getUnsealedSenderCardinalityLimiter()).thenReturn(cardinalityRateLimiter);
when(rateLimiters.getRateLimitResetLimiter()).thenReturn(mock(RateLimiter.class));
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final DynamicRateLimitsConfiguration rateLimitsConfiguration = mock(DynamicRateLimitsConfiguration.class);
rateLimitChallengeConfiguration = mock(DynamicRateLimitChallengeConfiguration.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getLimits()).thenReturn(rateLimitsConfiguration);
when(rateLimitsConfiguration.getUnsealedSenderDefaultCardinalityLimit()).thenReturn(1);
when(rateLimitsConfiguration.getUnsealedSenderPermitIncrement()).thenReturn(1);
when(dynamicConfiguration.getRateLimitChallengeConfiguration()).thenReturn(rateLimitChallengeConfiguration);
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(true);
unsealedSenderRateLimiter = new UnsealedSenderRateLimiter(rateLimiters, getRedisCluster(), dynamicConfigurationManager,
mock(RateLimitResetMetricsManager.class));
sender = mock(Account.class);
when(sender.getNumber()).thenReturn("+18005551111");
when(sender.getUuid()).thenReturn(UUID.randomUUID());
firstDestination = mock(Account.class);
when(firstDestination.getNumber()).thenReturn("+18005552222");
when(firstDestination.getUuid()).thenReturn(UUID.randomUUID());
secondDestination = mock(Account.class);
when(secondDestination.getNumber()).thenReturn("+18005553333");
when(secondDestination.getUuid()).thenReturn(UUID.randomUUID());
}
@Test
public void validate() throws RateLimitExceededException {
unsealedSenderRateLimiter.validate(sender, firstDestination);
assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, secondDestination));
unsealedSenderRateLimiter.validate(sender, firstDestination);
}
@Test
public void handleRateLimitReset() throws RateLimitExceededException {
unsealedSenderRateLimiter.validate(sender, firstDestination);
assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, secondDestination));
unsealedSenderRateLimiter.handleRateLimitReset(sender);
unsealedSenderRateLimiter.validate(sender, firstDestination);
unsealedSenderRateLimiter.validate(sender, secondDestination);
}
@Test
public void enforcementConfiguration() throws RateLimitExceededException {
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(false);
unsealedSenderRateLimiter.validate(sender, firstDestination);
unsealedSenderRateLimiter.validate(sender, secondDestination);
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(true);
final Account thirdDestination = mock(Account.class);
when(thirdDestination.getNumber()).thenReturn("+18005554444");
when(thirdDestination.getUuid()).thenReturn(UUID.randomUUID());
assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, thirdDestination));
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(false);
final Account fourthDestination = mock(Account.class);
when(fourthDestination.getNumber()).thenReturn("+18005555555");
when(fourthDestination.getUuid()).thenReturn(UUID.randomUUID());
unsealedSenderRateLimiter.validate(sender, fourthDestination);
}
}

View File

@@ -0,0 +1,76 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.amazonaws.services.dynamodbv2.model.AttributeDefinition;
import com.amazonaws.services.dynamodbv2.model.ScalarAttributeType;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.Random;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
class PushChallengeDynamoDbTest {
private PushChallengeDynamoDb pushChallengeDynamoDb;
private static final long CURRENT_TIME_MILLIS = 1_000_000_000;
private static final Random RANDOM = new Random();
private static final String TABLE_NAME = "push_challenge_test";
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder()
.tableName(TABLE_NAME)
.hashKey(PushChallengeDynamoDb.KEY_ACCOUNT_UUID)
.attributeDefinition(new AttributeDefinition(PushChallengeDynamoDb.KEY_ACCOUNT_UUID, ScalarAttributeType.B))
.build();
@BeforeEach
void setUp() {
this.pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbExtension.getDynamoDB(), TABLE_NAME, Clock.fixed(
Instant.ofEpochMilli(CURRENT_TIME_MILLIS), ZoneId.systemDefault()));
}
@Test
void add() {
final UUID uuid = UUID.randomUUID();
assertTrue(pushChallengeDynamoDb.add(uuid, generateRandomToken(), Duration.ofMinutes(1)));
assertFalse(pushChallengeDynamoDb.add(uuid, generateRandomToken(), Duration.ofMinutes(1)));
}
@Test
void remove() {
final UUID uuid = UUID.randomUUID();
final byte[] token = generateRandomToken();
assertFalse(pushChallengeDynamoDb.remove(uuid, token));
assertTrue(pushChallengeDynamoDb.add(uuid, token, Duration.ofMinutes(1)));
assertTrue(pushChallengeDynamoDb.remove(uuid, token));
}
@Test
void getExpirationTimestamp() {
assertEquals((CURRENT_TIME_MILLIS / 1000) + 3600,
pushChallengeDynamoDb.getExpirationTimestamp(Duration.ofHours(1)));
}
private static byte[] generateRandomToken() {
final byte[] token = new byte[16];
RANDOM.nextBytes(token);
return token;
}
}

View File

@@ -5,35 +5,23 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.collect.ImmutableSet;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
@@ -41,13 +29,42 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.PreKeyRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
public class KeysControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest {
private static final String EXISTS_NUMBER = "+14152222222";
private static final UUID EXISTS_UUID = UUID.randomUUID();
@@ -70,24 +87,28 @@ public class KeysControllerTest {
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" );
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid");
private final KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class );
private final AccountsManager accounts = mock(AccountsManager.class );
private final DirectoryQueue directoryQueue = mock(DirectoryQueue.class );
private final Account existsAccount = mock(Account.class );
private final static KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class );
private final static AccountsManager accounts = mock(AccountsManager.class );
private final static DirectoryQueue directoryQueue = mock(DirectoryQueue.class );
private final static PreKeyRateLimiter preKeyRateLimiter = mock(PreKeyRateLimiter.class );
private final static RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class );
private final static Account existsAccount = mock(Account.class );
private RateLimiters rateLimiters = mock(RateLimiters.class);
private RateLimiter rateLimiter = mock(RateLimiter.class );
private final static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class );
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, directoryQueue))
.addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, directoryQueue, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager))
.build();
@Before
public void setup() {
@BeforeEach
void setup() {
final Device sampleDevice = mock(Device.class);
final Device sampleDevice2 = mock(Device.class);
final Device sampleDevice3 = mock(Device.class);
@@ -153,8 +174,23 @@ public class KeysControllerTest {
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
}
@AfterEach
void teardown() {
reset(
keysDynamoDb,
accounts,
directoryQueue,
preKeyRateLimiter,
existsAccount,
rateLimiters,
rateLimiter,
dynamicConfigurationManager,
rateLimitChallengeManager
);
}
@Test
public void validKeyStatusTestByNumberV2() throws Exception {
void validKeyStatusTestByNumberV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys")
.request()
@@ -168,7 +204,7 @@ public class KeysControllerTest {
}
@Test
public void validKeyStatusTestByUuidV2() throws Exception {
void validKeyStatusTestByUuidV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys")
.request()
@@ -183,7 +219,7 @@ public class KeysControllerTest {
@Test
public void getSignedPreKeyV2ByNumber() throws Exception {
void getSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@@ -196,7 +232,7 @@ public class KeysControllerTest {
}
@Test
public void getSignedPreKeyV2ByUuid() throws Exception {
void getSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@@ -209,7 +245,7 @@ public class KeysControllerTest {
}
@Test
public void putSignedPreKeyV2ByNumber() throws Exception {
void putSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -224,7 +260,7 @@ public class KeysControllerTest {
}
@Test
public void putSignedPreKeyV2ByUuid() throws Exception {
void putSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9998, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -240,7 +276,7 @@ public class KeysControllerTest {
@Test
public void disabledPutSignedPreKeyV2ByNumber() throws Exception {
void disabledPutSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -252,7 +288,7 @@ public class KeysControllerTest {
}
@Test
public void disabledPutSignedPreKeyV2ByUuid() throws Exception {
void disabledPutSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -265,7 +301,7 @@ public class KeysControllerTest {
@Test
public void validSingleRequestTestV2ByNumber() throws Exception {
void validSingleRequestTestV2ByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -283,7 +319,7 @@ public class KeysControllerTest {
}
@Test
public void validSingleRequestTestV2ByUuid() throws Exception {
void validSingleRequestTestV2ByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
@@ -302,7 +338,7 @@ public class KeysControllerTest {
@Test
public void testUnidentifiedRequestByNumber() throws Exception {
void testUnidentifiedRequestByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -320,7 +356,7 @@ public class KeysControllerTest {
}
@Test
public void testUnidentifiedRequestByUuid() throws Exception {
void testUnidentifiedRequestByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID.toString()))
.request()
@@ -337,9 +373,23 @@ public class KeysControllerTest {
verifyNoMoreInteractions(keysDynamoDb);
}
@Test
void testNoDevices() {
when(existsAccount.getDevices()).thenReturn(Collections.emptySet());
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get();
assertThat(result).isNotNull();
assertThat(result.getStatus()).isEqualTo(404);
}
@Test
public void testUnauthorizedUnidentifiedRequest() throws Exception {
void testUnauthorizedUnidentifiedRequest() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -351,7 +401,7 @@ public class KeysControllerTest {
}
@Test
public void testMalformedUnidentifiedRequest() throws Exception {
void testMalformedUnidentifiedRequest() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -364,7 +414,7 @@ public class KeysControllerTest {
@Test
public void validMultiRequestTestV2ByNumber() throws Exception {
void validMultiRequestTestV2ByNumber() throws Exception {
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_NUMBER))
.request()
@@ -414,7 +464,7 @@ public class KeysControllerTest {
}
@Test
public void validMultiRequestTestV2ByUuid() throws Exception {
void validMultiRequestTestV2ByUuid() throws Exception {
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
@@ -465,7 +515,7 @@ public class KeysControllerTest {
@Test
public void invalidRequestTestV2() throws Exception {
void invalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s", NOT_EXISTS_NUMBER))
.request()
@@ -476,7 +526,7 @@ public class KeysControllerTest {
}
@Test
public void anotherInvalidRequestTestV2() throws Exception {
void anotherInvalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/22", EXISTS_NUMBER))
.request()
@@ -487,7 +537,7 @@ public class KeysControllerTest {
}
@Test
public void unauthorizedRequestTestV2() throws Exception {
void unauthorizedRequestTestV2() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
@@ -507,7 +557,7 @@ public class KeysControllerTest {
}
@Test
public void putKeysTestV2() throws Exception {
void putKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
@@ -541,7 +591,7 @@ public class KeysControllerTest {
}
@Test
public void disabledPutKeysTestV2() throws Exception {
void disabledPutKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
@@ -574,5 +624,42 @@ public class KeysControllerTest {
verify(accounts).update(AuthHelper.DISABLED_ACCOUNT);
}
@Test
void testRateLimitChallenge() throws RateLimitExceededException {
Duration retryAfter = Duration.ofMinutes(1);
doThrow(new RateLimitExceededException(retryAfter))
.when(preKeyRateLimiter).validate(any());
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge("Signal-Android/5.1.2 Android/30")).thenReturn(true);
when(rateLimitChallengeManager.getChallengeOptions(AuthHelper.VALID_ACCOUNT))
.thenReturn(List.of(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE, RateLimitChallengeManager.OPTION_RECAPTCHA));
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.get();
// unidentified access should not be rate limited
assertThat(result.getStatus()).isEqualTo(200);
result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.get();
assertThat(result.getStatus()).isEqualTo(428);
RateLimitChallenge rateLimitChallenge = result.readEntity(RateLimitChallenge.class);
assertThat(rateLimitChallenge.getToken()).isNotBlank();
assertThat(rateLimitChallenge.getOptions()).isNotEmpty();
assertThat(rateLimitChallenge.getOptions()).contains("recaptcha");
assertThat(rateLimitChallenge.getOptions()).contains("pushChallenge");
assertThat(Long.parseLong(result.getHeaderString("Retry-After"))).isEqualTo(retryAfter.toSeconds());
}
}

View File

@@ -31,6 +31,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixtur
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableSet;
import com.vdurmont.semver4j.Semver;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
@@ -57,8 +58,8 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.stubbing.Answer;
@@ -67,6 +68,7 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
@@ -74,11 +76,15 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.UnsealedSenderRateLimiter;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
@@ -91,6 +97,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest {
@@ -104,6 +111,8 @@ class MessageControllerTest {
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
private Account internationalAccount;
@SuppressWarnings("unchecked")
private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class);
@@ -114,8 +123,10 @@ class MessageControllerTest {
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final CardinalityRateLimiter unsealedSenderLimiter = mock(CardinalityRateLimiter.class);
private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands);
private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class);
@@ -125,9 +136,10 @@ class MessageControllerTest {
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
messagesManager, apnFallbackManager, dynamicConfigurationManager, metricsCluster, receiptExecutor))
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, metricsCluster, receiptExecutor))
.build();
@BeforeEach
@@ -148,7 +160,7 @@ class MessageControllerTest {
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, singleDeviceList, "1234".getBytes());
Account multiDeviceAccount = new Account(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, multiDeviceList, "1234".getBytes());
Account internationalAccount = new Account(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, singleDeviceList, "1234".getBytes());
internationalAccount = new Account(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, singleDeviceList, "1234".getBytes());
when(accountsManager.get(eq(SINGLE_DEVICE_RECIPIENT))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(SINGLE_DEVICE_RECIPIENT)))).thenReturn(Optional.of(singleDeviceAccount));
@@ -158,7 +170,6 @@ class MessageControllerTest {
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(INTERNATIONAL_RECIPIENT)))).thenReturn(Optional.of(internationalAccount));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getUnsealedSenderLimiter()).thenReturn(unsealedSenderLimiter);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
@@ -179,9 +190,10 @@ class MessageControllerTest {
messagesManager,
rateLimiters,
rateLimiter,
unsealedSenderLimiter,
unsealedSenderRateLimiter,
apnFallbackManager,
dynamicConfigurationManager,
rateLimitChallengeManager,
metricsCluster,
receiptExecutor
);
@@ -254,8 +266,8 @@ class MessageControllerTest {
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testUnsealedSenderCardinalityRateLimited(final boolean rateLimited) throws Exception {
@CsvSource({"true, 5.1.0, 413", "true, 5.6.4, 428", "false, 5.6.4, 200"})
void testUnsealedSenderCardinalityRateLimited(final boolean rateLimited, final String clientVersion, final int expectedStatusCode) throws Exception {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
final DynamicMessageRateConfiguration messageRateConfiguration = mock(DynamicMessageRateConfiguration.class);
@@ -268,11 +280,23 @@ class MessageControllerTest {
when(messageRateConfiguration.getReceiptDelayJitter()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptProbability()).thenReturn(1.0);
DynamicRateLimitChallengeConfiguration dynamicRateLimitChallengeConfiguration = mock(
DynamicRateLimitChallengeConfiguration.class);
when(dynamicConfiguration.getRateLimitChallengeConfiguration())
.thenReturn(dynamicRateLimitChallengeConfiguration);
when(dynamicRateLimitChallengeConfiguration.getMinimumSupportedVersion(any())).thenReturn(Optional.empty());
when(dynamicRateLimitChallengeConfiguration.getMinimumSupportedVersion(ClientPlatform.ANDROID))
.thenReturn(Optional.of(new Semver("5.5.0")));
when(redisCommands.evalsha(any(), any(), any(), any())).thenReturn(List.of(1L, 1L));
if (rateLimited) {
doThrow(RateLimitExceededException.class)
.when(unsealedSenderLimiter).validate(eq(AuthHelper.VALID_NUMBER), eq(INTERNATIONAL_RECIPIENT));
doThrow(new RateLimitExceededException(Duration.ofHours(1)))
.when(unsealedSenderRateLimiter).validate(eq(AuthHelper.VALID_ACCOUNT), eq(internationalAccount));
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge(String.format("Signal-Android/%s Android/30", clientVersion)))
.thenReturn(true);
}
Response response =
@@ -280,18 +304,50 @@ class MessageControllerTest {
.target(String.format("/v1/messages/%s", INTERNATIONAL_RECIPIENT))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.6.4 Android/30")
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
if (rateLimited) {
assertThat("Error Response", response.getStatus(), is(equalTo(413)));
assertThat("Error Response", response.getStatus(), is(equalTo(expectedStatusCode)));
} else {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
assertThat("Good Response", response.getStatus(), is(equalTo(expectedStatusCode)));
}
verify(messageSender, rateLimited ? never() : times(1)).sendMessage(any(), any(), any(), anyBoolean());
}
@Test
void testRateLimitResetRequirement() throws Exception {
Duration retryAfter = Duration.ofMinutes(1);
doThrow(new RateLimitExceededException(retryAfter))
.when(unsealedSenderRateLimiter).validate(any(), any());
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge("Signal-Android/5.1.2 Android/30")).thenReturn(true);
when(rateLimitChallengeManager.getChallengeOptions(AuthHelper.VALID_ACCOUNT))
.thenReturn(List.of(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE, RateLimitChallengeManager.OPTION_RECAPTCHA));
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", INTERNATIONAL_RECIPIENT))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertEquals(428, response.getStatus());
RateLimitChallenge rateLimitChallenge = response.readEntity(RateLimitChallenge.class);
assertFalse(rateLimitChallenge.getToken().isBlank());
assertFalse(rateLimitChallenge.getOptions().isEmpty());
assertTrue(rateLimitChallenge.getOptions().contains("recaptcha"));
assertTrue(rateLimitChallenge.getOptions().contains("pushChallenge"));
assertEquals(retryAfter.toSeconds(), Long.parseLong(response.getHeaderString("Retry-After")));
}
@Test
void testSingleDeviceCurrentUnidentified() throws Exception {
Response response =

View File

@@ -1,15 +1,16 @@
package org.whispersystems.textsecuregcm.tests.limits;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.time.Duration;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter;
@@ -18,13 +19,13 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class DynamicRateLimitsTest {
class DynamicRateLimitsTest {
private DynamicConfigurationManager dynamicConfig;
private FaultTolerantRedisCluster redisCluster;
@Before
public void setup() {
@BeforeEach
void setup() {
this.dynamicConfig = mock(DynamicConfigurationManager.class);
this.redisCluster = mock(FaultTolerantRedisCluster.class);
@@ -34,7 +35,7 @@ public class DynamicRateLimitsTest {
}
@Test
public void testUnchangingConfiguration() {
void testUnchangingConfiguration() {
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster);
RateLimiter limiter = rateLimiters.getUnsealedIpLimiter();
@@ -45,34 +46,39 @@ public class DynamicRateLimitsTest {
}
@Test
public void testChangingConfiguration() {
void testChangingConfiguration() {
DynamicConfiguration configuration = mock(DynamicConfiguration.class);
DynamicRateLimitsConfiguration limitsConfiguration = mock(DynamicRateLimitsConfiguration.class);
when(configuration.getLimits()).thenReturn(limitsConfiguration);
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(10, Duration.ofHours(1), Duration.ofMinutes(10)));
when(limitsConfiguration.getUnsealedSenderIp()).thenReturn(new RateLimitsConfiguration.RateLimitConfiguration(4, 1.0));
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(10, Duration.ofHours(1)));
when(limitsConfiguration.getRecaptchaChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getRecaptchaChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getDailyPreKeys()).thenReturn(new RateLimitConfiguration());
final RateLimitConfiguration initialRateLimitConfiguration = new RateLimitConfiguration(4, 1.0);
when(limitsConfiguration.getUnsealedSenderIp()).thenReturn(initialRateLimitConfiguration);
when(limitsConfiguration.getRateLimitReset()).thenReturn(initialRateLimitConfiguration);
when(dynamicConfig.getConfiguration()).thenReturn(configuration);
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster);
CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderLimiter();
CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderCardinalityLimiter();
assertThat(limiter.getMaxCardinality()).isEqualTo(10);
assertThat(limiter.getTtl()).isEqualTo(Duration.ofHours(1));
assertThat(limiter.getTtlJitter()).isEqualTo(Duration.ofMinutes(10));
assertSame(rateLimiters.getUnsealedSenderLimiter(), limiter);
assertThat(limiter.getDefaultMaxCardinality()).isEqualTo(10);
assertThat(limiter.getInitialTtl()).isEqualTo(Duration.ofHours(1));
assertSame(rateLimiters.getUnsealedSenderCardinalityLimiter(), limiter);
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(20, Duration.ofHours(2), Duration.ofMinutes(7)));
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(20, Duration.ofHours(2)));
CardinalityRateLimiter changed = rateLimiters.getUnsealedSenderLimiter();
CardinalityRateLimiter changed = rateLimiters.getUnsealedSenderCardinalityLimiter();
assertThat(changed.getMaxCardinality()).isEqualTo(20);
assertThat(changed.getTtl()).isEqualTo(Duration.ofHours(2));
assertThat(changed.getTtlJitter()).isEqualTo(Duration.ofMinutes(7));
assertThat(changed.getDefaultMaxCardinality()).isEqualTo(20);
assertThat(changed.getInitialTtl()).isEqualTo(Duration.ofHours(2));
assertNotSame(limiter, changed);
}
}

View File

@@ -19,6 +19,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.push.ApnMessage.Type;
import org.whispersystems.textsecuregcm.push.RetryingApnsClient;
import org.whispersystems.textsecuregcm.push.RetryingApnsClient.ApnResult;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -65,7 +66,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -99,7 +100,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -135,7 +136,7 @@ public class APNSenderTest {
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -238,7 +239,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -333,7 +334,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -366,7 +367,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), new Exception("lost connection")));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);