mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-20 05:28:05 +01:00
Move rate limiter logic to Lua scripts
This commit is contained in:
@@ -1,98 +0,0 @@
|
||||
/**
|
||||
* Copyright (C) 2013 Open WhisperSystems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class LeakyBucket {
|
||||
|
||||
private final int bucketSize;
|
||||
private final double leakRatePerMillis;
|
||||
|
||||
private int spaceRemaining;
|
||||
private long lastUpdateTimeMillis;
|
||||
|
||||
public LeakyBucket(int bucketSize, double leakRatePerMillis) {
|
||||
this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
|
||||
}
|
||||
|
||||
private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMillis;
|
||||
this.spaceRemaining = spaceRemaining;
|
||||
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||
}
|
||||
|
||||
public boolean add(int amount) {
|
||||
this.spaceRemaining = getUpdatedSpaceRemaining();
|
||||
this.lastUpdateTimeMillis = System.currentTimeMillis();
|
||||
|
||||
if (this.spaceRemaining >= amount) {
|
||||
this.spaceRemaining -= amount;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private int getUpdatedSpaceRemaining() {
|
||||
long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
|
||||
|
||||
return Math.min(this.bucketSize,
|
||||
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
|
||||
}
|
||||
|
||||
public String serialize(ObjectMapper mapper) throws JsonProcessingException {
|
||||
return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
|
||||
}
|
||||
|
||||
public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
|
||||
LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
|
||||
|
||||
return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
|
||||
entity.spaceRemaining, entity.lastUpdateTimeMillis);
|
||||
}
|
||||
|
||||
private static class LeakyBucketEntity {
|
||||
@JsonProperty
|
||||
private int bucketSize;
|
||||
|
||||
@JsonProperty
|
||||
private double leakRatePerMillis;
|
||||
|
||||
@JsonProperty
|
||||
private int spaceRemaining;
|
||||
|
||||
@JsonProperty
|
||||
private long lastUpdateTimeMillis;
|
||||
|
||||
public LeakyBucketEntity() {}
|
||||
|
||||
private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
|
||||
int spaceRemaining, long lastUpdateTimeMillis)
|
||||
{
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMillis;
|
||||
this.spaceRemaining = spaceRemaining;
|
||||
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import com.codahale.metrics.Meter;
|
||||
import com.codahale.metrics.MetricRegistry;
|
||||
import com.codahale.metrics.SharedMetricRegistries;
|
||||
import io.lettuce.core.SetArgs;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||
import org.whispersystems.textsecuregcm.util.Constants;
|
||||
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
import redis.clients.jedis.Jedis;
|
||||
|
||||
public class LockingRateLimiter extends RateLimiter {
|
||||
|
||||
private final Meter meter;
|
||||
|
||||
public LockingRateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
|
||||
super(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute);
|
||||
|
||||
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||
this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||
if (!acquireLock(key)) {
|
||||
meter.mark();
|
||||
throw new RateLimitExceededException("Locked");
|
||||
}
|
||||
|
||||
try {
|
||||
super.validate(key, amount);
|
||||
} finally {
|
||||
releaseLock(key);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(String key) throws RateLimitExceededException {
|
||||
validate(key, 1);
|
||||
}
|
||||
|
||||
private void releaseLock(String key) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String lockName = getLockName(key);
|
||||
|
||||
jedis.del(lockName);
|
||||
cacheCluster.useWriteCluster(connection -> connection.sync().del(lockName));
|
||||
}
|
||||
}
|
||||
|
||||
private boolean acquireLock(String key) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String lockName = getLockName(key);
|
||||
|
||||
final boolean acquiredLock = jedis.set(lockName, "L", "NX", "EX", 10) != null;
|
||||
|
||||
if (acquiredLock) {
|
||||
// TODO Restore the NX flag when the cluster becomes the primary source of truth
|
||||
cacheCluster.useWriteCluster(connection -> connection.sync().set(lockName, "L", SetArgs.Builder.ex(10)));
|
||||
}
|
||||
|
||||
return acquiredLock;
|
||||
}
|
||||
}
|
||||
|
||||
private String getLockName(String key) {
|
||||
return "leaky_lock::" + name + "::" + key;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -20,26 +20,25 @@ import com.codahale.metrics.Meter;
|
||||
import com.codahale.metrics.MetricRegistry;
|
||||
import com.codahale.metrics.SharedMetricRegistries;
|
||||
import com.codahale.metrics.Timer;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.lettuce.core.ScriptOutputType;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.experiment.Experiment;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.LuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||
import org.whispersystems.textsecuregcm.util.Constants;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
import redis.clients.jedis.Jedis;
|
||||
|
||||
public class RateLimiter {
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
|
||||
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
|
||||
public class RateLimiter {
|
||||
|
||||
private final Meter meter;
|
||||
private final Timer validateTimer;
|
||||
@@ -48,19 +47,12 @@ public class RateLimiter {
|
||||
protected final String name;
|
||||
private final int bucketSize;
|
||||
private final double leakRatePerMillis;
|
||||
private final boolean reportLimits;
|
||||
private final Experiment redisClusterExperiment;
|
||||
private final LuaScript validateScript;
|
||||
private final ClusterLuaScript clusterValidateScript;
|
||||
|
||||
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
|
||||
int bucketSize, double leakRatePerMinute)
|
||||
{
|
||||
this(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute, false);
|
||||
}
|
||||
|
||||
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
|
||||
int bucketSize, double leakRatePerMinute,
|
||||
boolean reportLimits)
|
||||
{
|
||||
int bucketSize, double leakRatePerMinute) throws IOException {
|
||||
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||
|
||||
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
|
||||
@@ -70,27 +62,37 @@ public class RateLimiter {
|
||||
this.name = name;
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
|
||||
this.reportLimits = reportLimits;
|
||||
this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name);
|
||||
}
|
||||
|
||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||
try (final Timer.Context ignored = validateTimer.time()) {
|
||||
LeakyBucket bucket = getBucket(key);
|
||||
|
||||
if (bucket.add(amount)) {
|
||||
setBucket(key, bucket);
|
||||
} else {
|
||||
meter.mark();
|
||||
throw new RateLimitExceededException(key + " , " + amount);
|
||||
}
|
||||
}
|
||||
this.validateScript = LuaScript.fromResource(cacheClient, "lua/validate_rate_limit.lua");
|
||||
this.clusterValidateScript = ClusterLuaScript.fromResource(cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
|
||||
}
|
||||
|
||||
public void validate(String key) throws RateLimitExceededException {
|
||||
validate(key, 1);
|
||||
}
|
||||
|
||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||
validate(key, amount, System.currentTimeMillis());
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void validate(String key, int amount, final long currentTimeMillis) throws RateLimitExceededException {
|
||||
try (final Timer.Context ignored = validateTimer.time()) {
|
||||
final List<String> keys = List.of(getBucketName(key));
|
||||
final List<String> arguments = List.of(String.valueOf(bucketSize), String.valueOf(leakRatePerMillis), String.valueOf(currentTimeMillis), String.valueOf(amount));
|
||||
|
||||
final Object result = validateScript.execute(keys.stream().map(k -> k.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()),
|
||||
arguments.stream().map(a -> a.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()));
|
||||
|
||||
redisClusterExperiment.compareSupplierResult(result, () -> clusterValidateScript.execute(keys, arguments));
|
||||
|
||||
if (result == null) {
|
||||
meter.mark();
|
||||
throw new RateLimitExceededException(key + " , " + amount);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void clear(String key) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String bucketName = getBucketName(key);
|
||||
@@ -100,37 +102,8 @@ public class RateLimiter {
|
||||
}
|
||||
}
|
||||
|
||||
private void setBucket(String key, LeakyBucket bucket) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String bucketName = getBucketName(key);
|
||||
final String serialized = bucket.serialize(mapper);
|
||||
final int level = (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000);
|
||||
|
||||
jedis.setex(bucketName, level, serialized);
|
||||
cacheCluster.useWriteCluster(connection -> connection.sync().setex(bucketName, level, serialized));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private LeakyBucket getBucket(String key) {
|
||||
try (Jedis jedis = cacheClient.getReadResource()) {
|
||||
final String bucketName = getBucketName(key);
|
||||
|
||||
String serialized = jedis.get(bucketName);
|
||||
redisClusterExperiment.compareSupplierResult(serialized, () -> cacheCluster.withReadCluster(connection -> connection.sync().get(bucketName)));
|
||||
|
||||
if (serialized != null) {
|
||||
return LeakyBucket.fromSerialized(mapper, serialized);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
logger.warn("Deserialization error", e);
|
||||
}
|
||||
|
||||
return new LeakyBucket(bucketSize, leakRatePerMillis);
|
||||
}
|
||||
|
||||
private String getBucketName(String key) {
|
||||
@VisibleForTesting
|
||||
String getBucketName(String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,8 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RateLimiters {
|
||||
|
||||
private final RateLimiter smsDestinationLimiter;
|
||||
@@ -48,7 +50,7 @@ public class RateLimiters {
|
||||
private final RateLimiter usernameLookupLimiter;
|
||||
private final RateLimiter usernameSetLimiter;
|
||||
|
||||
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) {
|
||||
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) throws IOException {
|
||||
this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination",
|
||||
config.getSmsDestination().getBucketSize(),
|
||||
config.getSmsDestination().getLeakRatePerMinute());
|
||||
@@ -73,11 +75,11 @@ public class RateLimiters {
|
||||
config.getAutoBlock().getBucketSize(),
|
||||
config.getAutoBlock().getLeakRatePerMinute());
|
||||
|
||||
this.verifyLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "verify",
|
||||
this.verifyLimiter = new RateLimiter(cacheClient, cacheCluster, "verify",
|
||||
config.getVerifyNumber().getBucketSize(),
|
||||
config.getVerifyNumber().getLeakRatePerMinute());
|
||||
|
||||
this.pinLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "pin",
|
||||
this.pinLimiter = new RateLimiter(cacheClient, cacheCluster, "pin",
|
||||
config.getVerifyPin().getBucketSize(),
|
||||
config.getVerifyPin().getLeakRatePerMinute());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user