Moving RateLimiter logic to Redis Lua and adding async API

This commit is contained in:
Sergey Skrobotov
2023-03-06 13:45:35 -08:00
parent 46fef4082c
commit 4c85e7ba66
17 changed files with 723 additions and 302 deletions

View File

@@ -1,85 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.SystemMapper;
class LeakyBucketTest {
@Test
void testFull() {
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(1));
assertTrue(leakyBucket.add(1));
assertFalse(leakyBucket.add(1));
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(2));
assertFalse(leakyBucket.add(1));
assertFalse(leakyBucket.add(2));
}
@Test
void testLapseRate() throws IOException {
ObjectMapper mapper = SystemMapper.jsonMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertTrue(leakyBucket.add(1));
String serializedAgain = leakyBucket.serialize(mapper);
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
assertFalse(leakyBucketAgain.add(1));
}
@Test
void testLapseShort() throws Exception {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertFalse(leakyBucket.add(1));
}
@Test
void testGetTimeUntilSpaceAvailable() throws Exception {
ObjectMapper mapper = new ObjectMapper();
{
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1));
assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000));
}
{
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1);
// TODO Refactor LeakyBucket to be more test-friendly and accept a Clock
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0);
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0);
}
}
}

View File

@@ -0,0 +1,147 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.lettuce.core.ScriptOutputType;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox;
import org.whispersystems.textsecuregcm.util.redis.SimpleCacheCommandsHandler;
public class RateLimitersLuaScriptTest {
@RegisterExtension
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
private final MutableClock clock = MockUtils.mutableClock(0);
private final RedisLuaScriptSandbox sandbox = RedisLuaScriptSandbox.fromResource(
"lua/validate_rate_limit.lua",
ScriptOutputType.INTEGER);
private final SimpleCacheCommandsHandler redisCommandsHandler = new SimpleCacheCommandsHandler(clock);
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
@Test
public void testWithEmbeddedRedis() throws Exception {
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
final RateLimiters limiters = new RateLimiters(
Map.of(descriptor.id(), new RateLimiterConfig(60, 60)),
dynamicConfig,
RateLimiters.defaultScript(redisCluster),
redisCluster,
Clock.systemUTC());
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
rateLimiter.validate("test", 25);
rateLimiter.validate("test", 25);
assertThrows(Exception.class, () -> rateLimiter.validate("test", 25));
}
@Test
public void testLuaBucketConfigurationUpdates() throws Exception {
final String key = "key1";
clock.setTimeMillis(0);
long result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 1, true),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize);
// now making a check-only call, but changing the bucket size
result = (long) sandbox.execute(
List.of(key),
scriptArgs(2000, 1, 1, false),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize);
}
@Test
public void testLuaUpdatesTokenBucket() throws Exception {
final String key = "key1";
clock.setTimeMillis(0);
long result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 200, true),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining);
// 50 tokens replenished, acquiring 100 more, should end up with 750 available
clock.setTimeMillis(50);
result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 100, true),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
// now checking without an update, should not affect the count
result = (long) sandbox.execute(
List.of(key),
scriptArgs(1000, 1, 100, false),
redisCommandsHandler
);
assertEquals(0L, result);
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
}
private Optional<TokenBucket> decodeBucket(final String key) {
try {
final String json = redisCommandsHandler.get(key);
return json == null
? Optional.empty()
: Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
private List<String> scriptArgs(
final long bucketSize,
final long ratePerMillis,
final long requestedAmount,
final boolean useTokens) {
return List.of(
String.valueOf(bucketSize),
String.valueOf(ratePerMillis),
String.valueOf(clock.millis()),
String.valueOf(requestedAmount),
String.valueOf(useTokens)
);
}
private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) {
}
}

View File

@@ -18,9 +18,11 @@ import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.MutableClock;
@SuppressWarnings("unchecked")
public class RateLimitersTest {
@@ -30,8 +32,12 @@ public class RateLimitersTest {
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
private final ClusterLuaScript validateScript = mock(ClusterLuaScript.class);
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
private final MutableClock clock = MockUtils.mutableClock(0);
private static final String BAD_YAML = """
limits:
smsVoicePrefix:
@@ -59,12 +65,12 @@ public class RateLimitersTest {
public void testValidateConfigs() throws Exception {
assertThrows(IllegalArgumentException.class, () -> {
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs();
});
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs();
}
@@ -79,18 +85,22 @@ public class RateLimitersTest {
new TestDescriptor[] { td1, td2, td3, tdDup },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {});
validateScript,
redisCluster,
clock) {});
new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3 },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {};
validateScript,
redisCluster,
clock) {};
}
@Test
void testUnchangingConfiguration() {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config());
@@ -109,7 +119,7 @@ public class RateLimitersTest {
when(configuration.getLimits()).thenReturn(limitsConfigMap);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
@@ -137,7 +147,7 @@ public class RateLimitersTest {
when(configuration.getLimits()).thenReturn(mapForDynamic);
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
// test only default is present

View File

@@ -0,0 +1,54 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util.redis;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail;
import java.util.List;
/**
* This class is to be extended with implementations of Redis commands as needed.
*/
public class BaseRedisCommandsHandler implements RedisCommandsHandler {
@Override
public Object redisCommand(final String command, final List<Object> args) {
return switch (command) {
case "SET" -> {
assertTrue(args.size() > 2);
yield set(args.get(0).toString(), args.get(1).toString(), tail(args, 2));
}
case "GET" -> {
assertEquals(1, args.size());
yield get(args.get(0).toString());
}
case "DEL" -> {
assertTrue(args.size() > 1);
yield del(args.get(0).toString());
}
default -> other(command, args);
};
}
public Object set(final String key, final String value, final List<Object> tail) {
return "OK";
}
public String get(final String key) {
return null;
}
public int del(final String key) {
return 0;
}
public Object other(final String command, final List<Object> args) {
return "OK";
}
}

View File

@@ -0,0 +1,14 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util.redis;
import java.util.List;
@FunctionalInterface
public interface RedisCommandsHandler {
Object redisCommand(String command, List<Object> args);
}

View File

@@ -0,0 +1,167 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util.redis;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.io.Resources;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import party.iroiro.luajava.Lua;
import party.iroiro.luajava.lua51.Lua51;
import party.iroiro.luajava.value.ImmutableLuaValue;
public class RedisLuaScriptSandbox {
private static final String PREFIX = """
function redis_call(...)
-- variable name needs to match the one used in the `L.setGlobal()` call
-- method name needs to match method name of the Java class
return proxy:redisCall(arg)
end
function json_encode(obj)
return mapper:encode(obj)
end
function json_decode(json)
return java.luaify(mapper:decode(json))
end
local redis = { call = redis_call }
local cjson = { encode = json_encode, decode = json_decode }
""";
private final String luaScript;
private final ScriptOutputType scriptOutputType;
public static RedisLuaScriptSandbox fromResource(
final String resource,
final ScriptOutputType scriptOutputType) {
try {
final String src = Resources.toString(Resources.getResource(resource), StandardCharsets.UTF_8);
return new RedisLuaScriptSandbox(src, scriptOutputType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public RedisLuaScriptSandbox(final String luaScript, final ScriptOutputType scriptOutputType) {
this.luaScript = luaScript;
this.scriptOutputType = scriptOutputType;
}
public Object execute(
final List<String> keys,
final List<String> args,
final RedisCommandsHandler redisCallsHandler) {
try (final Lua lua = new Lua51()) {
lua.openLibraries();
final RedisLuaProxy proxy = new RedisLuaProxy(redisCallsHandler);
lua.push(MapperLuaProxy.INSTANCE, Lua.Conversion.FULL);
lua.setGlobal("mapper");
lua.push(proxy, Lua.Conversion.FULL);
lua.setGlobal("proxy");
lua.push(keys, Lua.Conversion.FULL);
lua.setGlobal("KEYS");
lua.push(args, Lua.Conversion.FULL);
lua.setGlobal("ARGV");
final Lua.LuaError executionResult = lua.run(PREFIX + luaScript);
assertEquals("OK", executionResult.name(), "Runtime error during Lua script execution");
return adaptOutputResult(lua.get());
}
}
protected Object adaptOutputResult(final Object luaObject) {
if (luaObject instanceof ImmutableLuaValue<?> luaValue) {
final Object javaValue = luaValue.toJavaObject();
// validate expected script output type
switch (scriptOutputType) {
case INTEGER -> assertTrue(javaValue instanceof Double); // lua number is always Double
case STATUS -> assertTrue(javaValue instanceof String);
case BOOLEAN -> assertTrue(javaValue instanceof Boolean);
};
if (javaValue instanceof Double d) {
return d.longValue();
}
if (javaValue instanceof String s) {
return s;
}
if (javaValue instanceof Boolean b) {
return b;
}
if (javaValue == null) {
return null;
}
throw new IllegalStateException("unexpected script result java type: " + javaValue.getClass().getName());
}
throw new IllegalStateException("unexpected script result lua type: " + luaObject.getClass().getName());
}
public static <T> List<T> tail(final List<T> list, final int fromIdx) {
return fromIdx < list.size() ? list.subList(fromIdx, list.size()) : Collections.emptyList();
}
public static final class MapperLuaProxy {
public static final MapperLuaProxy INSTANCE = new MapperLuaProxy();
public String encode(final Map<Object, Object> obj) {
try {
return SystemMapper.jsonMapper().writeValueAsString(obj);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
public Map<Object, Object> decode(final Object json) {
try {
//noinspection unchecked
return SystemMapper.jsonMapper().readValue(json.toString(), Map.class);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}
/**
* Instances of this class are passed to the Lua scripting engine
* and serve as a stubs for the calls to `redis.call()`.
*
* @see #PREFIX
*/
public static final class RedisLuaProxy {
private final RedisCommandsHandler handler;
public RedisLuaProxy(final RedisCommandsHandler handler) {
this.handler = handler;
}
/**
* Method name needs to match the one from the {@link #PREFIX} code.
* The method is getting called from the Lua scripting engine.
*/
@SuppressWarnings("unused")
public Object redisCall(final List<Object> args) {
assertFalse(args.isEmpty(), "`redis.call()` in Lua script invoked without arguments");
assertTrue(args.get(0) instanceof String, "first argument to `redis.call()` must be of type `String`");
return handler.redisCommand((String) args.get(0), tail(args, 1));
}
}
}

View File

@@ -0,0 +1,73 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util.redis;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
public record Entry(String value, long expirationEpochMillis) {
}
private final Map<String, Entry> cache = new ConcurrentHashMap<>();
private final Clock clock;
public SimpleCacheCommandsHandler(final Clock clock) {
this.clock = clock;
}
@Override
public Object set(final String key, final String value, final List<Object> tail) {
cache.put(key, new Entry(value, resolveExpirationEpochMillis(tail)));
return "OK";
}
@Override
public String get(final String key) {
final Entry entry = cache.get(key);
if (entry == null) {
return null;
}
if (entry.expirationEpochMillis() < clock.millis()) {
del(key);
return null;
}
return entry.value();
}
@Override
public int del(final String key) {
return cache.remove(key) != null ? 1 : 0;
}
protected long resolveExpirationEpochMillis(final List<Object> args) {
for (int i = 0; i < args.size() - 1; i++) {
final long currentTimeMillis = clock.millis();
final String param = args.get(i).toString();
final String value = args.get(i + 1).toString();
switch (param) {
case "EX" -> {
return currentTimeMillis + Double.valueOf(value).longValue() * 1000;
}
case "PX" -> {
return currentTimeMillis + Double.valueOf(value).longValue();
}
case "EXAT" -> {
return Double.valueOf(value).longValue() * 1000;
}
case "PXAT" -> {
return Double.valueOf(value).longValue();
}
}
}
return Long.MAX_VALUE;
}
}