Rate limiters code refactored

This commit is contained in:
Sergey Skrobotov
2023-02-23 10:21:39 -08:00
parent 378b32d44d
commit 7529c35013
35 changed files with 738 additions and 774 deletions

View File

@@ -0,0 +1,84 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
class LeakyBucketTest {
@Test
void testFull() {
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(1));
assertTrue(leakyBucket.add(1));
assertFalse(leakyBucket.add(1));
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(2));
assertFalse(leakyBucket.add(1));
assertFalse(leakyBucket.add(2));
}
@Test
void testLapseRate() throws IOException {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertTrue(leakyBucket.add(1));
String serializedAgain = leakyBucket.serialize(mapper);
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
assertFalse(leakyBucketAgain.add(1));
}
@Test
void testLapseShort() throws Exception {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertFalse(leakyBucket.add(1));
}
@Test
void testGetTimeUntilSpaceAvailable() throws Exception {
ObjectMapper mapper = new ObjectMapper();
{
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1));
assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000));
}
{
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1);
// TODO Refactor LeakyBucket to be more test-friendly and accept a Clock
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0);
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0);
}
}
}

View File

@@ -1,3 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.mockito.ArgumentMatchers.any;
@@ -12,17 +17,17 @@ import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.captcha.AssessmentResult;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.storage.Account;
class RateLimitChallengeManagerTest {
private PushChallengeManager pushChallengeManager;
private CaptchaChecker captchaChecker;
private DynamicRateLimiters rateLimiters;
private RateLimiters rateLimiters;
private RateLimitChallengeListener rateLimitChallengeListener;
private RateLimitChallengeManager rateLimitChallengeManager;
@@ -31,7 +36,7 @@ class RateLimitChallengeManagerTest {
void setUp() {
pushChallengeManager = mock(PushChallengeManager.class);
captchaChecker = mock(CaptchaChecker.class);
rateLimiters = mock(DynamicRateLimiters.class);
rateLimiters = mock(RateLimiters.class);
rateLimitChallengeListener = mock(RateLimitChallengeListener.class);
rateLimitChallengeManager = new RateLimitChallengeManager(

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@@ -30,13 +30,13 @@ import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
class RateLimitChallengeOptionManagerTest {
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
private DynamicRateLimiters rateLimiters;
private RateLimiters rateLimiters;
private RateLimitChallengeOptionManager rateLimitChallengeOptionManager;
@BeforeEach
void setUp() {
rateLimiters = mock(DynamicRateLimiters.class);
rateLimiters = mock(RateLimiters.class);
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);

View File

@@ -11,7 +11,6 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration;
import java.util.Optional;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
@@ -43,14 +42,14 @@ public class RateLimitedByIpTest {
public static class Controller {
@GET
@Path("/strict")
@RateLimitedByIp(RateLimiters.Handle.BACKUP_AUTH_CHECK)
@RateLimitedByIp(RateLimiters.For.BACKUP_AUTH_CHECK)
public Response strict() {
return Response.ok().build();
}
@GET
@Path("/loose")
@RateLimitedByIp(value = RateLimiters.Handle.BACKUP_AUTH_CHECK, failOnUnresolvedIp = false)
@RateLimitedByIp(value = RateLimiters.For.BACKUP_AUTH_CHECK, failOnUnresolvedIp = false)
public Response loose() {
return Response.ok().build();
}
@@ -59,7 +58,7 @@ public class RateLimitedByIpTest {
private static final RateLimiter RATE_LIMITER = Mockito.mock(RateLimiter.class);
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rl ->
Mockito.when(rl.byHandle(Mockito.eq(RateLimiters.Handle.BACKUP_AUTH_CHECK))).thenReturn(Optional.of(RATE_LIMITER)));
Mockito.when(rl.forDescriptor(Mockito.eq(RateLimiters.For.BACKUP_AUTH_CHECK))).thenReturn(RATE_LIMITER));
private static final ResourceExtension RESOURCES = ResourceExtension.builder()
.setMapper(SystemMapper.getMapper())

View File

@@ -0,0 +1,176 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
@SuppressWarnings("unchecked")
public class RateLimitersTest {
private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
private static final String BAD_YAML = """
limits:
smsVoicePrefix:
bucketSize: 150
leakRatePerMinute: 10
unexpected:
bucketSize: 4
leakRatePerMinute: 2
""";
private static final String GOOD_YAML = """
limits:
smsVoicePrefix:
bucketSize: 150
leakRatePerMinute: 10
attachmentCreate:
bucketSize: 4
leakRatePerMinute: 2
""";
public record GenericHolder(@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits) {
}
@Test
public void testValidateConfigs() throws Exception {
assertThrows(IllegalArgumentException.class, () -> {
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
rateLimiters.validateValuesAndConfigs();
});
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
rateLimiters.validateValuesAndConfigs();
}
@Test
public void testValidateDuplicates() throws Exception {
final TestDescriptor td1 = new TestDescriptor("id1");
final TestDescriptor td2 = new TestDescriptor("id2");
final TestDescriptor td3 = new TestDescriptor("id3");
final TestDescriptor tdDup = new TestDescriptor("id1");
assertThrows(IllegalStateException.class, () -> new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3, tdDup },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {});
new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3 },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {};
}
@Test
void testUnchangingConfiguration() {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config());
}
@Test
void testChangingConfiguration() {
final RateLimiterConfig initialRateLimiterConfig = new RateLimiterConfig(4, 1);
final RateLimiterConfig updatedRateLimiterCongig = new RateLimiterConfig(17, 19);
final RateLimiterConfig baseConfig = new RateLimiterConfig(1, 1);
final Map<String, RateLimiterConfig> limitsConfigMap = new HashMap<>();
limitsConfigMap.put(RateLimiters.For.RECAPTCHA_CHALLENGE_ATTEMPT.id(), baseConfig);
limitsConfigMap.put(RateLimiters.For.RECAPTCHA_CHALLENGE_SUCCESS.id(), baseConfig);
when(configuration.getLimits()).thenReturn(limitsConfigMap);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
assertEquals(initialRateLimiterConfig, limiter.config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
assertEquals(updatedRateLimiterCongig, limiter.config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
}
@Test
public void testRateLimiterHasItsPrioritiesStraight() throws Exception {
final RateLimiters.For descriptor = RateLimiters.For.RECAPTCHA_CHALLENGE_ATTEMPT;
final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, 1);
final RateLimiterConfig configForStatic = new RateLimiterConfig(2, 2);
final RateLimiterConfig defaultConfig = descriptor.defaultConfig();
final Map<String, RateLimiterConfig> mapForDynamic = new HashMap<>();
final Map<String, RateLimiterConfig> mapForStatic = new HashMap<>();
when(configuration.getLimits()).thenReturn(mapForDynamic);
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
// test only default is present
mapForDynamic.remove(descriptor.id());
mapForStatic.remove(descriptor.id());
assertEquals(defaultConfig, limiter.config());
// test dynamic and no static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.remove(descriptor.id());
assertEquals(configForDynamic, limiter.config());
// test dynamic and static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForDynamic, limiter.config());
// test static, but no dynamic
mapForDynamic.remove(descriptor.id());
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForStatic, limiter.config());
}
private record TestDescriptor(String id) implements RateLimiterDescriptor {
@Override
public boolean isDynamic() {
return false;
}
@Override
public RateLimiterConfig defaultConfig() {
return new RateLimiterConfig(1, 1);
}
}
}