mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-22 05:18:04 +01:00
Moving RateLimiter logic to Redis Lua and adding async API
This commit is contained in:
@@ -1,85 +0,0 @@
|
||||
/*
|
||||
* Copyright 2013 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
|
||||
class LeakyBucketTest {
|
||||
|
||||
@Test
|
||||
void testFull() {
|
||||
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||
|
||||
assertTrue(leakyBucket.add(1));
|
||||
assertTrue(leakyBucket.add(1));
|
||||
assertFalse(leakyBucket.add(1));
|
||||
|
||||
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||
|
||||
assertTrue(leakyBucket.add(2));
|
||||
assertFalse(leakyBucket.add(1));
|
||||
assertFalse(leakyBucket.add(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testLapseRate() throws IOException {
|
||||
ObjectMapper mapper = SystemMapper.jsonMapper();
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
assertTrue(leakyBucket.add(1));
|
||||
|
||||
String serializedAgain = leakyBucket.serialize(mapper);
|
||||
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
|
||||
|
||||
assertFalse(leakyBucketAgain.add(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testLapseShort() throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
assertFalse(leakyBucket.add(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetTimeUntilSpaceAvailable() throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
{
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
|
||||
assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1));
|
||||
assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000));
|
||||
}
|
||||
|
||||
{
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
|
||||
Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1);
|
||||
|
||||
// TODO Refactor LeakyBucket to be more test-friendly and accept a Clock
|
||||
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0);
|
||||
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.lettuce.core.ScriptOutputType;
|
||||
import java.time.Clock;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.MutableClock;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox;
|
||||
import org.whispersystems.textsecuregcm.util.redis.SimpleCacheCommandsHandler;
|
||||
|
||||
public class RateLimitersLuaScriptTest {
|
||||
|
||||
@RegisterExtension
|
||||
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
|
||||
|
||||
private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
|
||||
|
||||
private final MutableClock clock = MockUtils.mutableClock(0);
|
||||
|
||||
private final RedisLuaScriptSandbox sandbox = RedisLuaScriptSandbox.fromResource(
|
||||
"lua/validate_rate_limit.lua",
|
||||
ScriptOutputType.INTEGER);
|
||||
|
||||
private final SimpleCacheCommandsHandler redisCommandsHandler = new SimpleCacheCommandsHandler(clock);
|
||||
|
||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
|
||||
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
|
||||
|
||||
@Test
|
||||
public void testWithEmbeddedRedis() throws Exception {
|
||||
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
||||
final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
|
||||
final RateLimiters limiters = new RateLimiters(
|
||||
Map.of(descriptor.id(), new RateLimiterConfig(60, 60)),
|
||||
dynamicConfig,
|
||||
RateLimiters.defaultScript(redisCluster),
|
||||
redisCluster,
|
||||
Clock.systemUTC());
|
||||
|
||||
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
|
||||
rateLimiter.validate("test", 25);
|
||||
rateLimiter.validate("test", 25);
|
||||
assertThrows(Exception.class, () -> rateLimiter.validate("test", 25));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLuaBucketConfigurationUpdates() throws Exception {
|
||||
final String key = "key1";
|
||||
clock.setTimeMillis(0);
|
||||
long result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 1, true),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize);
|
||||
|
||||
// now making a check-only call, but changing the bucket size
|
||||
result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(2000, 1, 1, false),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLuaUpdatesTokenBucket() throws Exception {
|
||||
final String key = "key1";
|
||||
clock.setTimeMillis(0);
|
||||
long result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 200, true),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining);
|
||||
|
||||
// 50 tokens replenished, acquiring 100 more, should end up with 750 available
|
||||
clock.setTimeMillis(50);
|
||||
result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 100, true),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
|
||||
|
||||
// now checking without an update, should not affect the count
|
||||
result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 100, false),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
|
||||
}
|
||||
|
||||
private Optional<TokenBucket> decodeBucket(final String key) {
|
||||
try {
|
||||
final String json = redisCommandsHandler.get(key);
|
||||
return json == null
|
||||
? Optional.empty()
|
||||
: Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private List<String> scriptArgs(
|
||||
final long bucketSize,
|
||||
final long ratePerMillis,
|
||||
final long requestedAmount,
|
||||
final boolean useTokens) {
|
||||
return List.of(
|
||||
String.valueOf(bucketSize),
|
||||
String.valueOf(ratePerMillis),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(requestedAmount),
|
||||
String.valueOf(useTokens)
|
||||
);
|
||||
}
|
||||
|
||||
private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) {
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,11 @@ import javax.validation.Valid;
|
||||
import javax.validation.constraints.NotNull;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.MutableClock;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public class RateLimitersTest {
|
||||
@@ -30,8 +32,12 @@ public class RateLimitersTest {
|
||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
|
||||
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
|
||||
|
||||
private final ClusterLuaScript validateScript = mock(ClusterLuaScript.class);
|
||||
|
||||
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
|
||||
|
||||
private final MutableClock clock = MockUtils.mutableClock(0);
|
||||
|
||||
private static final String BAD_YAML = """
|
||||
limits:
|
||||
smsVoicePrefix:
|
||||
@@ -59,12 +65,12 @@ public class RateLimitersTest {
|
||||
public void testValidateConfigs() throws Exception {
|
||||
assertThrows(IllegalArgumentException.class, () -> {
|
||||
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
rateLimiters.validateValuesAndConfigs();
|
||||
});
|
||||
|
||||
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
rateLimiters.validateValuesAndConfigs();
|
||||
}
|
||||
|
||||
@@ -79,18 +85,22 @@ public class RateLimitersTest {
|
||||
new TestDescriptor[] { td1, td2, td3, tdDup },
|
||||
Collections.emptyMap(),
|
||||
dynamicConfig,
|
||||
redisCluster) {});
|
||||
validateScript,
|
||||
redisCluster,
|
||||
clock) {});
|
||||
|
||||
new BaseRateLimiters<>(
|
||||
new TestDescriptor[] { td1, td2, td3 },
|
||||
Collections.emptyMap(),
|
||||
dynamicConfig,
|
||||
redisCluster) {};
|
||||
validateScript,
|
||||
redisCluster,
|
||||
clock) {};
|
||||
}
|
||||
|
||||
@Test
|
||||
void testUnchangingConfiguration() {
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
|
||||
assertEquals(expected, limiter.config());
|
||||
@@ -109,7 +119,7 @@ public class RateLimitersTest {
|
||||
|
||||
when(configuration.getLimits()).thenReturn(limitsConfigMap);
|
||||
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||
|
||||
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
|
||||
@@ -137,7 +147,7 @@ public class RateLimitersTest {
|
||||
|
||||
when(configuration.getLimits()).thenReturn(mapForDynamic);
|
||||
|
||||
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock);
|
||||
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
|
||||
|
||||
// test only default is present
|
||||
|
||||
Reference in New Issue
Block a user