Moving RateLimiter logic to Redis Lua and adding async API

This commit is contained in:
Sergey Skrobotov
2023-03-06 13:45:35 -08:00
parent 46fef4082c
commit 4c85e7ba66
17 changed files with 723 additions and 302 deletions

View File

@@ -7,7 +7,11 @@ package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.time.Clock;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
@@ -17,6 +21,7 @@ import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@@ -33,12 +38,14 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
final T[] values,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
final ClusterLuaScript validateScript,
final FaultTolerantRedisCluster cacheCluster,
final Clock clock) {
this.configs = configs;
this.rateLimiterByDescriptor = Arrays.stream(values)
.map(descriptor -> Pair.of(
descriptor,
createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster)))
createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
}
@@ -62,11 +69,22 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
}
}
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisCluster cacheCluster) {
try {
return ClusterLuaScript.fromResource(
cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
} catch (final IOException e) {
throw new UncheckedIOException("Failed to load rate limit validation script", e);
}
}
private static RateLimiter createForDescriptor(
final RateLimiterDescriptor descriptor,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
final ClusterLuaScript validateScript,
final FaultTolerantRedisCluster cacheCluster,
final Clock clock) {
if (descriptor.isDynamic()) {
final Supplier<RateLimiterConfig> configResolver = () -> {
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
@@ -74,9 +92,9 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
};
return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster);
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
}
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster);
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock);
}
}

View File

@@ -7,10 +7,13 @@ package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import java.time.Clock;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class DynamicRateLimiter implements RateLimiter {
@@ -19,18 +22,26 @@ public class DynamicRateLimiter implements RateLimiter {
private final Supplier<RateLimiterConfig> configResolver;
private final ClusterLuaScript validateScript;
private final FaultTolerantRedisCluster cluster;
private final Clock clock;
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
public DynamicRateLimiter(
final String name,
final Supplier<RateLimiterConfig> configResolver,
final FaultTolerantRedisCluster cluster) {
final ClusterLuaScript validateScript,
final FaultTolerantRedisCluster cluster,
final Clock clock) {
this.name = requireNonNull(name);
this.configResolver = requireNonNull(configResolver);
this.validateScript = requireNonNull(validateScript);
this.cluster = requireNonNull(cluster);
this.clock = requireNonNull(clock);
}
@Override
@@ -38,16 +49,31 @@ public class DynamicRateLimiter implements RateLimiter {
current().getRight().validate(key, amount);
}
@Override
public CompletionStage<Void> validateAsync(final String key, final int amount) {
return current().getRight().validateAsync(key, amount);
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return current().getRight().hasAvailablePermits(key, permits);
}
@Override
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
return current().getRight().hasAvailablePermitsAsync(key, amount);
}
@Override
public void clear(final String key) {
current().getRight().clear(key);
}
@Override
public CompletionStage<Void> clearAsync(final String key) {
return current().getRight().clearAsync(key);
}
@Override
public RateLimiterConfig config() {
return current().getLeft();
@@ -57,7 +83,7 @@ public class DynamicRateLimiter implements RateLimiter {
final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster))
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock))
);
}
}

View File

@@ -1,99 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
public class LeakyBucket {
private final int bucketSize;
private final double leakRatePerMillis;
private int spaceRemaining;
private long lastUpdateTimeMillis;
public LeakyBucket(int bucketSize, double leakRatePerMillis) {
this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
}
private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
this.bucketSize = bucketSize;
this.leakRatePerMillis = leakRatePerMillis;
this.spaceRemaining = spaceRemaining;
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
}
public boolean add(int amount) {
this.spaceRemaining = getUpdatedSpaceRemaining();
this.lastUpdateTimeMillis = System.currentTimeMillis();
if (this.spaceRemaining >= amount) {
this.spaceRemaining -= amount;
return true;
} else {
return false;
}
}
private int getUpdatedSpaceRemaining() {
long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
return Math.min(this.bucketSize,
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
}
public Duration getTimeUntilSpaceAvailable(int amount) {
int currentSpaceRemaining = getUpdatedSpaceRemaining();
if (currentSpaceRemaining >= amount) {
return Duration.ZERO;
} else if (amount > this.bucketSize) {
// This shouldn't happen today but if so we should bubble this to the clients somehow
throw new IllegalArgumentException("Requested permits exceed maximum bucket size");
} else {
return Duration.ofMillis((long)Math.ceil((double)(amount - currentSpaceRemaining) / this.leakRatePerMillis));
}
}
public String serialize(ObjectMapper mapper) throws JsonProcessingException {
return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
}
public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
entity.spaceRemaining, entity.lastUpdateTimeMillis);
}
private static class LeakyBucketEntity {
@JsonProperty
private int bucketSize;
@JsonProperty
private double leakRatePerMillis;
@JsonProperty
private int spaceRemaining;
@JsonProperty
private long lastUpdateTimeMillis;
public LeakyBucketEntity() {}
private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
int spaceRemaining, long lastUpdateTimeMillis)
{
this.bucketSize = bucketSize;
this.leakRatePerMillis = leakRatePerMillis;
this.spaceRemaining = spaceRemaining;
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
}
}
}

View File

@@ -1,58 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.lettuce.core.SetArgs;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Duration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class LockingRateLimiter extends StaticRateLimiter {
private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION
= new RateLimitExceededException(Duration.ZERO, true);
private final Counter counter;
public LockingRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
super(name, config, cacheCluster);
this.counter = Metrics.counter(name(getClass(), "locked"), "name", name);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
if (!acquireLock(key)) {
counter.increment();
throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION;
}
try {
super.validate(key, amount);
} finally {
releaseLock(key);
}
}
private void releaseLock(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key)));
}
private boolean acquireLock(final String key) {
return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null;
}
private String getLockName(final String key) {
return "leaky_lock::" + name + "::" + key;
}
}

View File

@@ -6,16 +6,23 @@
package org.whispersystems.textsecuregcm.limits;
import java.util.UUID;
import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public interface RateLimiter {
void validate(String key, int amount) throws RateLimitExceededException;
CompletionStage<Void> validateAsync(String key, int amount);
boolean hasAvailablePermits(String key, int permits);
CompletionStage<Boolean> hasAvailablePermitsAsync(String key, int amount);
void clear(String key);
CompletionStage<Void> clearAsync(String key);
RateLimiterConfig config();
default void validate(final String key) throws RateLimitExceededException {
@@ -30,14 +37,34 @@ public interface RateLimiter {
validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
}
default CompletionStage<Void> validateAsync(final String key) {
return validateAsync(key, 1);
}
default CompletionStage<Void> validateAsync(final UUID accountUuid) {
return validateAsync(accountUuid.toString());
}
default CompletionStage<Void> validateAsync(final UUID srcAccountUuid, final UUID dstAccountUuid) {
return validateAsync(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
}
default boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
return hasAvailablePermits(accountUuid.toString(), permits);
}
default CompletionStage<Boolean> hasAvailablePermitsAsync(final UUID accountUuid, final int permits) {
return hasAvailablePermitsAsync(accountUuid.toString(), permits);
}
default void clear(final UUID accountUuid) {
clear(accountUuid.toString());
}
default CompletionStage<Void> clearAsync(final UUID accountUuid) {
return clearAsync(accountUuid.toString());
}
/**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code true}

View File

@@ -6,8 +6,10 @@ package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
import java.time.Clock;
import java.util.Map;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@@ -103,7 +105,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster);
final RateLimiters rateLimiters = new RateLimiters(
configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
rateLimiters.validateValuesAndConfigs();
return rateLimiters;
}
@@ -112,8 +115,10 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
RateLimiters(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
super(For.values(), configs, dynamicConfigurationManager, cacheCluster);
final ClusterLuaScript validateScript,
final FaultTolerantRedisCluster cacheCluster,
final Clock clock) {
super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
}
public RateLimiter getAllocateDeviceLimiter() {

View File

@@ -5,60 +5,96 @@
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
public class StaticRateLimiter implements RateLimiter {
private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class);
protected final String name;
private final RateLimiterConfig config;
protected final FaultTolerantRedisCluster cacheCluster;
private final Counter counter;
private final ClusterLuaScript validateScript;
private final FaultTolerantRedisCluster cacheCluster;
private final Clock clock;
public StaticRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
final ClusterLuaScript validateScript,
final FaultTolerantRedisCluster cacheCluster,
final Clock clock) {
this.name = requireNonNull(name);
this.config = requireNonNull(config);
this.validateScript = requireNonNull(validateScript);
this.cacheCluster = requireNonNull(cacheCluster);
this.counter = Metrics.counter(name(getClass(), "exceeded"), "name", name);
this.clock = requireNonNull(clock);
this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
final LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
final long deficitPermitsAmount = executeValidateScript(key, amount, true);
if (deficitPermitsAmount > 0) {
counter.increment();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
throw new RateLimitExceededException(retryAfter, true);
}
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
public CompletionStage<Void> validateAsync(final String key, final int amount) {
return executeValidateScriptAsync(key, amount, true)
.thenCompose(deficitPermitsAmount -> {
if (deficitPermitsAmount == 0) {
return completedFuture(null);
}
counter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
return failedFuture(new RateLimitExceededException(retryAfter, true));
});
}
@Override
public boolean hasAvailablePermits(final String key, final int amount) {
final long deficitPermitsAmount = executeValidateScript(key, amount, false);
return deficitPermitsAmount == 0;
}
@Override
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
return executeValidateScriptAsync(key, amount, false)
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0);
}
@Override
public void clear(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key)));
}
@Override
public CompletionStage<Void> clearAsync(final String key) {
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key)))
.thenRun(Util.NOOP);
}
@Override
@@ -66,33 +102,31 @@ public class StaticRateLimiter implements RateLimiter {
return config;
}
private void setBucket(final String key, final LeakyBucket bucket) {
try {
final String serialized = bucket.serialize(SystemMapper.jsonMapper());
cacheCluster.useCluster(connection -> connection.sync().setex(
getBucketName(key),
(int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000),
serialized));
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return (Long) validateScript.execute(keys, arguments);
}
private LeakyBucket getBucket(final String key) {
try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(SystemMapper.jsonMapper(), serialized);
}
} catch (final IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis());
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
}
private String getBucketName(final String key) {
private String bucketName(final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}