Move rate limiter logic to Lua scripts

This commit is contained in:
Jon Chambers
2020-07-06 10:10:13 -04:00
committed by GitHub
parent f5ddb0f1f8
commit b585c6676d
8 changed files with 214 additions and 294 deletions

View File

@@ -0,0 +1,88 @@
package org.whispersystems.textsecuregcm.limits;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.AbstractRedisTest;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import redis.clients.jedis.Jedis;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
public class RateLimiterTest extends AbstractRedisTest {
private static final long NOW_MILLIS = System.currentTimeMillis();
private static final String KEY = "key";
@FunctionalInterface
private interface RateLimitedTask {
void run() throws RateLimitExceededException;
}
@Before
public void clearCache() {
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
jedis.flushAll();
}
}
@Test
public void validate() throws RateLimitExceededException, IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
rateLimiter.validate(KEY, 1, NOW_MILLIS);
rateLimiter.validate(KEY, 1, NOW_MILLIS);
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
@Test
public void validateWithAmount() throws RateLimitExceededException, IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
rateLimiter.validate(KEY, 2, NOW_MILLIS);
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
@Test
public void testLapseRate() throws RateLimitExceededException, IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}";
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
}
rateLimiter.validate(KEY, 1, NOW_MILLIS);
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
@Test
public void testLapseShort() throws IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}";
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
}
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
private void assertRateLimitExceeded(final RateLimitedTask task) {
try {
task.run();
fail("Expected RateLimitExceededException");
} catch (final RateLimitExceededException ignored) {
}
}
@SuppressWarnings("SameParameterValue")
private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException {
final double leakRatePerMinute = leakRatePerMilli * 60_000d;
return new RateLimiter(getReplicatedJedisPool(), mock(FaultTolerantRedisCluster.class), KEY, bucketSize, leakRatePerMinute);
}
}

View File

@@ -0,0 +1,48 @@
package org.whispersystems.textsecuregcm.redis;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import redis.embedded.RedisServer;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.URISyntaxException;
import java.util.List;
public abstract class AbstractRedisTest {
private static RedisServer redisServer;
private ReplicatedJedisPool replicatedJedisPool;
@BeforeClass
public static void setUpBeforeClass() throws IOException {
redisServer = new RedisServer(getNextPort());
redisServer.start();
}
@Before
public void setUp() throws URISyntaxException {
final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
}
protected ReplicatedJedisPool getReplicatedJedisPool() {
return replicatedJedisPool;
}
@AfterClass
public static void tearDownAfterClass() {
redisServer.stop();
}
private static int getNextPort() throws IOException {
try (ServerSocket socket = new ServerSocket(0)) {
socket.setReuseAddress(false);
return socket.getLocalPort();
}
}
}

View File

@@ -1,52 +0,0 @@
package org.whispersystems.textsecuregcm.tests.limits;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.Test;
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class LeakyBucketTest {
@Test
public void testFull() {
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(1));
assertTrue(leakyBucket.add(1));
assertFalse(leakyBucket.add(1));
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(2));
assertFalse(leakyBucket.add(1));
assertFalse(leakyBucket.add(2));
}
@Test
public void testLapseRate() throws IOException {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertTrue(leakyBucket.add(1));
String serializedAgain = leakyBucket.serialize(mapper);
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
assertFalse(leakyBucketAgain.add(1));
}
@Test
public void testLapseShort() throws Exception {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertFalse(leakyBucket.add(1));
}
}