Moving RateLimiter logic to Redis Lua and adding async API

This commit is contained in:
Sergey Skrobotov
2023-03-06 13:45:35 -08:00
parent 46fef4082c
commit 4c85e7ba66
17 changed files with 723 additions and 302 deletions

View File

@@ -1,85 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.SystemMapper;
class LeakyBucketTest {
@Test
void testFull() {
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(1));
assertTrue(leakyBucket.add(1));
assertFalse(leakyBucket.add(1));
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(2));
assertFalse(leakyBucket.add(1));
assertFalse(leakyBucket.add(2));
}
@Test
void testLapseRate() throws IOException {
ObjectMapper mapper = SystemMapper.jsonMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertTrue(leakyBucket.add(1));
String serializedAgain = leakyBucket.serialize(mapper);
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
assertFalse(leakyBucketAgain.add(1));
}
@Test
void testLapseShort() throws Exception {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertFalse(leakyBucket.add(1));
}
@Test
void testGetTimeUntilSpaceAvailable() throws Exception {
ObjectMapper mapper = new ObjectMapper();
{
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1));
assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000));
}
{
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1);
// TODO Refactor LeakyBucket to be more test-friendly and accept a Clock
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0);
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0);
}
}
}

View File

@@ -0,0 +1,147 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.lettuce.core.ScriptOutputType;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox;
import org.whispersystems.textsecuregcm.util.redis.SimpleCacheCommandsHandler;
public class RateLimitersLuaScriptTest {
@RegisterExtension
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
private final MutableClock clock = MockUtils.mutableClock(0);
private final RedisLuaScriptSandbox sandbox = RedisLuaScriptSandbox.fromResource(
"lua/validate_rate_limit.lua",
ScriptOutputType.INTEGER);
private final SimpleCacheCommandsHandler redisCommandsHandler = new SimpleCacheCommandsHandler(clock);
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
@Test
public void testWithEmbeddedRedis() throws Exception {
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
final RateLimiters limiters = new RateLimiters(
Map.of(descriptor.id(), new RateLimiterConfig(60, 60)),
dynamicConfig,
RateLimiters.defaultScript(redisCluster),
redisCluster,
Clock.systemUTC());
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
rateLimiter.validate("test", 25);
rateLimiter.validate("test", 25);
assertThrows(Exception.class, () -> rateLimiter.validate("test", 25));
}
@Test
public void testLuaBucketConfigurationUpdates() throws Exception {
final String key = "key1";
clock.setTimeMillis(0);
long result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 1, true),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize);
// now making a check-only call, but changing the bucket size
result = (long) sandbox.execute(
List.of(key),
scriptArgs(2000, 1, 1, false),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize);
}
@Test
public void testLuaUpdatesTokenBucket() throws Exception {
final String key = "key1";
clock.setTimeMillis(0);
long result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 200, true),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining);
// 50 tokens replenished, acquiring 100 more, should end up with 750 available
clock.setTimeMillis(50);
result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 100, true),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
// now checking without an update, should not affect the count
result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 100, false),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
}
private Optional<TokenBucket> decodeBucket(final String key) {
try {
final String json = redisCommandsHandler.get(key);
return json == null
? Optional.empty()
: Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
private List<String> scriptArgs(
final long bucketSize,
final long ratePerMillis,
final long requestedAmount,
final boolean useTokens) {
return List.of(
String.valueOf(bucketSize),
String.valueOf(ratePerMillis),
String.valueOf(clock.millis()),
String.valueOf(requestedAmount),
String.valueOf(useTokens)
);
}
private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) {
}
}

View File

@@ -18,9 +18,11 @@ import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.MutableClock;
@SuppressWarnings("unchecked")
public class RateLimitersTest {
@@ -30,8 +32,12 @@ public class RateLimitersTest {
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
private final ClusterLuaScript validateScript = mock(ClusterLuaScript.class);
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
private final MutableClock clock = MockUtils.mutableClock(0);
private static final String BAD_YAML = """
limits:
smsVoicePrefix:
@@ -59,12 +65,12 @@ public class RateLimitersTest {
public void testValidateConfigs() throws Exception {
assertThrows(IllegalArgumentException.class, () -> {
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs();
});
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs();
}
@@ -79,18 +85,22 @@ public class RateLimitersTest {
new TestDescriptor[] { td1, td2, td3, tdDup },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {});
validateScript,
redisCluster,
clock) {});
new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3 },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {};
validateScript,
redisCluster,
clock) {};
}
@Test
void testUnchangingConfiguration() {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config());
@@ -109,7 +119,7 @@ public class RateLimitersTest {
when(configuration.getLimits()).thenReturn(limitsConfigMap);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
@@ -137,7 +147,7 @@ public class RateLimitersTest {
when(configuration.getLimits()).thenReturn(mapForDynamic);
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
// test only default is present