Add client challenges for prekey and message rate limiters

This commit is contained in:
Jon Chambers
2021-05-11 17:21:32 -04:00
committed by GitHub
parent 5752853bba
commit 46110d4d65
46 changed files with 2289 additions and 255 deletions

View File

@@ -5,35 +5,23 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.collect.ImmutableSet;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
@@ -41,13 +29,42 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.PreKeyRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
public class KeysControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest {
private static final String EXISTS_NUMBER = "+14152222222";
private static final UUID EXISTS_UUID = UUID.randomUUID();
@@ -70,24 +87,28 @@ public class KeysControllerTest {
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" );
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid");
private final KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class );
private final AccountsManager accounts = mock(AccountsManager.class );
private final DirectoryQueue directoryQueue = mock(DirectoryQueue.class );
private final Account existsAccount = mock(Account.class );
private final static KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class );
private final static AccountsManager accounts = mock(AccountsManager.class );
private final static DirectoryQueue directoryQueue = mock(DirectoryQueue.class );
private final static PreKeyRateLimiter preKeyRateLimiter = mock(PreKeyRateLimiter.class );
private final static RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class );
private final static Account existsAccount = mock(Account.class );
private RateLimiters rateLimiters = mock(RateLimiters.class);
private RateLimiter rateLimiter = mock(RateLimiter.class );
private final static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class );
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, directoryQueue))
.addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, directoryQueue, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager))
.build();
@Before
public void setup() {
@BeforeEach
void setup() {
final Device sampleDevice = mock(Device.class);
final Device sampleDevice2 = mock(Device.class);
final Device sampleDevice3 = mock(Device.class);
@@ -153,8 +174,23 @@ public class KeysControllerTest {
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
}
@AfterEach
void teardown() {
reset(
keysDynamoDb,
accounts,
directoryQueue,
preKeyRateLimiter,
existsAccount,
rateLimiters,
rateLimiter,
dynamicConfigurationManager,
rateLimitChallengeManager
);
}
@Test
public void validKeyStatusTestByNumberV2() throws Exception {
void validKeyStatusTestByNumberV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys")
.request()
@@ -168,7 +204,7 @@ public class KeysControllerTest {
}
@Test
public void validKeyStatusTestByUuidV2() throws Exception {
void validKeyStatusTestByUuidV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys")
.request()
@@ -183,7 +219,7 @@ public class KeysControllerTest {
@Test
public void getSignedPreKeyV2ByNumber() throws Exception {
void getSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@@ -196,7 +232,7 @@ public class KeysControllerTest {
}
@Test
public void getSignedPreKeyV2ByUuid() throws Exception {
void getSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@@ -209,7 +245,7 @@ public class KeysControllerTest {
}
@Test
public void putSignedPreKeyV2ByNumber() throws Exception {
void putSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -224,7 +260,7 @@ public class KeysControllerTest {
}
@Test
public void putSignedPreKeyV2ByUuid() throws Exception {
void putSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9998, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -240,7 +276,7 @@ public class KeysControllerTest {
@Test
public void disabledPutSignedPreKeyV2ByNumber() throws Exception {
void disabledPutSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -252,7 +288,7 @@ public class KeysControllerTest {
}
@Test
public void disabledPutSignedPreKeyV2ByUuid() throws Exception {
void disabledPutSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@@ -265,7 +301,7 @@ public class KeysControllerTest {
@Test
public void validSingleRequestTestV2ByNumber() throws Exception {
void validSingleRequestTestV2ByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -283,7 +319,7 @@ public class KeysControllerTest {
}
@Test
public void validSingleRequestTestV2ByUuid() throws Exception {
void validSingleRequestTestV2ByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
@@ -302,7 +338,7 @@ public class KeysControllerTest {
@Test
public void testUnidentifiedRequestByNumber() throws Exception {
void testUnidentifiedRequestByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -320,7 +356,7 @@ public class KeysControllerTest {
}
@Test
public void testUnidentifiedRequestByUuid() throws Exception {
void testUnidentifiedRequestByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID.toString()))
.request()
@@ -337,9 +373,23 @@ public class KeysControllerTest {
verifyNoMoreInteractions(keysDynamoDb);
}
@Test
void testNoDevices() {
when(existsAccount.getDevices()).thenReturn(Collections.emptySet());
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get();
assertThat(result).isNotNull();
assertThat(result.getStatus()).isEqualTo(404);
}
@Test
public void testUnauthorizedUnidentifiedRequest() throws Exception {
void testUnauthorizedUnidentifiedRequest() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -351,7 +401,7 @@ public class KeysControllerTest {
}
@Test
public void testMalformedUnidentifiedRequest() throws Exception {
void testMalformedUnidentifiedRequest() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@@ -364,7 +414,7 @@ public class KeysControllerTest {
@Test
public void validMultiRequestTestV2ByNumber() throws Exception {
void validMultiRequestTestV2ByNumber() throws Exception {
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_NUMBER))
.request()
@@ -414,7 +464,7 @@ public class KeysControllerTest {
}
@Test
public void validMultiRequestTestV2ByUuid() throws Exception {
void validMultiRequestTestV2ByUuid() throws Exception {
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
@@ -465,7 +515,7 @@ public class KeysControllerTest {
@Test
public void invalidRequestTestV2() throws Exception {
void invalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s", NOT_EXISTS_NUMBER))
.request()
@@ -476,7 +526,7 @@ public class KeysControllerTest {
}
@Test
public void anotherInvalidRequestTestV2() throws Exception {
void anotherInvalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/22", EXISTS_NUMBER))
.request()
@@ -487,7 +537,7 @@ public class KeysControllerTest {
}
@Test
public void unauthorizedRequestTestV2() throws Exception {
void unauthorizedRequestTestV2() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
@@ -507,7 +557,7 @@ public class KeysControllerTest {
}
@Test
public void putKeysTestV2() throws Exception {
void putKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
@@ -541,7 +591,7 @@ public class KeysControllerTest {
}
@Test
public void disabledPutKeysTestV2() throws Exception {
void disabledPutKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
@@ -574,5 +624,42 @@ public class KeysControllerTest {
verify(accounts).update(AuthHelper.DISABLED_ACCOUNT);
}
@Test
void testRateLimitChallenge() throws RateLimitExceededException {
Duration retryAfter = Duration.ofMinutes(1);
doThrow(new RateLimitExceededException(retryAfter))
.when(preKeyRateLimiter).validate(any());
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge("Signal-Android/5.1.2 Android/30")).thenReturn(true);
when(rateLimitChallengeManager.getChallengeOptions(AuthHelper.VALID_ACCOUNT))
.thenReturn(List.of(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE, RateLimitChallengeManager.OPTION_RECAPTCHA));
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.get();
// unidentified access should not be rate limited
assertThat(result.getStatus()).isEqualTo(200);
result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.get();
assertThat(result.getStatus()).isEqualTo(428);
RateLimitChallenge rateLimitChallenge = result.readEntity(RateLimitChallenge.class);
assertThat(rateLimitChallenge.getToken()).isNotBlank();
assertThat(rateLimitChallenge.getOptions()).isNotEmpty();
assertThat(rateLimitChallenge.getOptions()).contains("recaptcha");
assertThat(rateLimitChallenge.getOptions()).contains("pushChallenge");
assertThat(Long.parseLong(result.getHeaderString("Retry-After"))).isEqualTo(retryAfter.toSeconds());
}
}

View File

@@ -31,6 +31,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixtur
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableSet;
import com.vdurmont.semver4j.Semver;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
@@ -57,8 +58,8 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.stubbing.Answer;
@@ -67,6 +68,7 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
@@ -74,11 +76,15 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.UnsealedSenderRateLimiter;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
@@ -91,6 +97,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest {
@@ -104,6 +111,8 @@ class MessageControllerTest {
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
private Account internationalAccount;
@SuppressWarnings("unchecked")
private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class);
@@ -114,8 +123,10 @@ class MessageControllerTest {
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final CardinalityRateLimiter unsealedSenderLimiter = mock(CardinalityRateLimiter.class);
private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands);
private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class);
@@ -125,9 +136,10 @@ class MessageControllerTest {
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
messagesManager, apnFallbackManager, dynamicConfigurationManager, metricsCluster, receiptExecutor))
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, metricsCluster, receiptExecutor))
.build();
@BeforeEach
@@ -148,7 +160,7 @@ class MessageControllerTest {
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, singleDeviceList, "1234".getBytes());
Account multiDeviceAccount = new Account(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, multiDeviceList, "1234".getBytes());
Account internationalAccount = new Account(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, singleDeviceList, "1234".getBytes());
internationalAccount = new Account(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, singleDeviceList, "1234".getBytes());
when(accountsManager.get(eq(SINGLE_DEVICE_RECIPIENT))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(SINGLE_DEVICE_RECIPIENT)))).thenReturn(Optional.of(singleDeviceAccount));
@@ -158,7 +170,6 @@ class MessageControllerTest {
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(INTERNATIONAL_RECIPIENT)))).thenReturn(Optional.of(internationalAccount));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getUnsealedSenderLimiter()).thenReturn(unsealedSenderLimiter);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
@@ -179,9 +190,10 @@ class MessageControllerTest {
messagesManager,
rateLimiters,
rateLimiter,
unsealedSenderLimiter,
unsealedSenderRateLimiter,
apnFallbackManager,
dynamicConfigurationManager,
rateLimitChallengeManager,
metricsCluster,
receiptExecutor
);
@@ -254,8 +266,8 @@ class MessageControllerTest {
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testUnsealedSenderCardinalityRateLimited(final boolean rateLimited) throws Exception {
@CsvSource({"true, 5.1.0, 413", "true, 5.6.4, 428", "false, 5.6.4, 200"})
void testUnsealedSenderCardinalityRateLimited(final boolean rateLimited, final String clientVersion, final int expectedStatusCode) throws Exception {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
final DynamicMessageRateConfiguration messageRateConfiguration = mock(DynamicMessageRateConfiguration.class);
@@ -268,11 +280,23 @@ class MessageControllerTest {
when(messageRateConfiguration.getReceiptDelayJitter()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptProbability()).thenReturn(1.0);
DynamicRateLimitChallengeConfiguration dynamicRateLimitChallengeConfiguration = mock(
DynamicRateLimitChallengeConfiguration.class);
when(dynamicConfiguration.getRateLimitChallengeConfiguration())
.thenReturn(dynamicRateLimitChallengeConfiguration);
when(dynamicRateLimitChallengeConfiguration.getMinimumSupportedVersion(any())).thenReturn(Optional.empty());
when(dynamicRateLimitChallengeConfiguration.getMinimumSupportedVersion(ClientPlatform.ANDROID))
.thenReturn(Optional.of(new Semver("5.5.0")));
when(redisCommands.evalsha(any(), any(), any(), any())).thenReturn(List.of(1L, 1L));
if (rateLimited) {
doThrow(RateLimitExceededException.class)
.when(unsealedSenderLimiter).validate(eq(AuthHelper.VALID_NUMBER), eq(INTERNATIONAL_RECIPIENT));
doThrow(new RateLimitExceededException(Duration.ofHours(1)))
.when(unsealedSenderRateLimiter).validate(eq(AuthHelper.VALID_ACCOUNT), eq(internationalAccount));
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge(String.format("Signal-Android/%s Android/30", clientVersion)))
.thenReturn(true);
}
Response response =
@@ -280,18 +304,50 @@ class MessageControllerTest {
.target(String.format("/v1/messages/%s", INTERNATIONAL_RECIPIENT))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.6.4 Android/30")
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
if (rateLimited) {
assertThat("Error Response", response.getStatus(), is(equalTo(413)));
assertThat("Error Response", response.getStatus(), is(equalTo(expectedStatusCode)));
} else {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
assertThat("Good Response", response.getStatus(), is(equalTo(expectedStatusCode)));
}
verify(messageSender, rateLimited ? never() : times(1)).sendMessage(any(), any(), any(), anyBoolean());
}
@Test
void testRateLimitResetRequirement() throws Exception {
Duration retryAfter = Duration.ofMinutes(1);
doThrow(new RateLimitExceededException(retryAfter))
.when(unsealedSenderRateLimiter).validate(any(), any());
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge("Signal-Android/5.1.2 Android/30")).thenReturn(true);
when(rateLimitChallengeManager.getChallengeOptions(AuthHelper.VALID_ACCOUNT))
.thenReturn(List.of(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE, RateLimitChallengeManager.OPTION_RECAPTCHA));
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", INTERNATIONAL_RECIPIENT))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertEquals(428, response.getStatus());
RateLimitChallenge rateLimitChallenge = response.readEntity(RateLimitChallenge.class);
assertFalse(rateLimitChallenge.getToken().isBlank());
assertFalse(rateLimitChallenge.getOptions().isEmpty());
assertTrue(rateLimitChallenge.getOptions().contains("recaptcha"));
assertTrue(rateLimitChallenge.getOptions().contains("pushChallenge"));
assertEquals(retryAfter.toSeconds(), Long.parseLong(response.getHeaderString("Retry-After")));
}
@Test
void testSingleDeviceCurrentUnidentified() throws Exception {
Response response =

View File

@@ -1,15 +1,16 @@
package org.whispersystems.textsecuregcm.tests.limits;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.time.Duration;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter;
@@ -18,13 +19,13 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class DynamicRateLimitsTest {
class DynamicRateLimitsTest {
private DynamicConfigurationManager dynamicConfig;
private FaultTolerantRedisCluster redisCluster;
@Before
public void setup() {
@BeforeEach
void setup() {
this.dynamicConfig = mock(DynamicConfigurationManager.class);
this.redisCluster = mock(FaultTolerantRedisCluster.class);
@@ -34,7 +35,7 @@ public class DynamicRateLimitsTest {
}
@Test
public void testUnchangingConfiguration() {
void testUnchangingConfiguration() {
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster);
RateLimiter limiter = rateLimiters.getUnsealedIpLimiter();
@@ -45,34 +46,39 @@ public class DynamicRateLimitsTest {
}
@Test
public void testChangingConfiguration() {
void testChangingConfiguration() {
DynamicConfiguration configuration = mock(DynamicConfiguration.class);
DynamicRateLimitsConfiguration limitsConfiguration = mock(DynamicRateLimitsConfiguration.class);
when(configuration.getLimits()).thenReturn(limitsConfiguration);
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(10, Duration.ofHours(1), Duration.ofMinutes(10)));
when(limitsConfiguration.getUnsealedSenderIp()).thenReturn(new RateLimitsConfiguration.RateLimitConfiguration(4, 1.0));
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(10, Duration.ofHours(1)));
when(limitsConfiguration.getRecaptchaChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getRecaptchaChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getDailyPreKeys()).thenReturn(new RateLimitConfiguration());
final RateLimitConfiguration initialRateLimitConfiguration = new RateLimitConfiguration(4, 1.0);
when(limitsConfiguration.getUnsealedSenderIp()).thenReturn(initialRateLimitConfiguration);
when(limitsConfiguration.getRateLimitReset()).thenReturn(initialRateLimitConfiguration);
when(dynamicConfig.getConfiguration()).thenReturn(configuration);
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster);
CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderLimiter();
CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderCardinalityLimiter();
assertThat(limiter.getMaxCardinality()).isEqualTo(10);
assertThat(limiter.getTtl()).isEqualTo(Duration.ofHours(1));
assertThat(limiter.getTtlJitter()).isEqualTo(Duration.ofMinutes(10));
assertSame(rateLimiters.getUnsealedSenderLimiter(), limiter);
assertThat(limiter.getDefaultMaxCardinality()).isEqualTo(10);
assertThat(limiter.getInitialTtl()).isEqualTo(Duration.ofHours(1));
assertSame(rateLimiters.getUnsealedSenderCardinalityLimiter(), limiter);
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(20, Duration.ofHours(2), Duration.ofMinutes(7)));
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(20, Duration.ofHours(2)));
CardinalityRateLimiter changed = rateLimiters.getUnsealedSenderLimiter();
CardinalityRateLimiter changed = rateLimiters.getUnsealedSenderCardinalityLimiter();
assertThat(changed.getMaxCardinality()).isEqualTo(20);
assertThat(changed.getTtl()).isEqualTo(Duration.ofHours(2));
assertThat(changed.getTtlJitter()).isEqualTo(Duration.ofMinutes(7));
assertThat(changed.getDefaultMaxCardinality()).isEqualTo(20);
assertThat(changed.getInitialTtl()).isEqualTo(Duration.ofHours(2));
assertNotSame(limiter, changed);
}
}

View File

@@ -19,6 +19,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.push.ApnMessage.Type;
import org.whispersystems.textsecuregcm.push.RetryingApnsClient;
import org.whispersystems.textsecuregcm.push.RetryingApnsClient.ApnResult;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -65,7 +66,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -99,7 +100,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -135,7 +136,7 @@ public class APNSenderTest {
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -238,7 +239,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -333,7 +334,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@@ -366,7 +367,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), new Exception("lost connection")));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);