Simplify rate limiters by making them all dynamic

This commit is contained in:
Jonathan Klabunde Tomer
2025-05-21 10:29:26 -07:00
committed by GitHub
parent aafcd63a9f
commit 35604cf151
12 changed files with 449 additions and 472 deletions

View File

@@ -10,16 +10,12 @@ import static java.util.Objects.requireNonNull;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.time.Clock;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
@@ -27,25 +23,18 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private final Map<T, RateLimiter> rateLimiterByDescriptor;
private final Map<String, RateLimiterConfig> configs;
protected BaseRateLimiters(
final T[] values,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock) {
this.configs = configs;
this.rateLimiterByDescriptor = Arrays.stream(values)
.map(descriptor -> Pair.of(
descriptor,
createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
createForDescriptor(descriptor, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
}
@@ -53,22 +42,6 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
return requireNonNull(rateLimiterByDescriptor.get(handle));
}
public void validateValuesAndConfigs() {
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
.map(RateLimiterDescriptor::id)
.collect(Collectors.toSet());
for (final String key: configs.keySet()) {
if (!ids.contains(key)) {
final String message = String.format(
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
key
);
logger.error(message);
throw new IllegalArgumentException(message);
}
}
}
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisClusterClient cacheCluster) {
try {
return ClusterLuaScript.fromResource(
@@ -80,21 +53,12 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
private static RateLimiter createForDescriptor(
final RateLimiterDescriptor descriptor,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock) {
if (descriptor.isDynamic()) {
final Supplier<RateLimiterConfig> configResolver = () -> {
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
return config != null
? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
};
return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock);
}
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock);
final Supplier<RateLimiterConfig> configResolver =
() -> dynamicConfigurationManager.getConfiguration().getLimits().getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
}
}

View File

@@ -7,87 +7,167 @@ package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
public class DynamicRateLimiter implements RateLimiter {
private final String name;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final Supplier<RateLimiterConfig> configResolver;
private final ClusterLuaScript validateScript;
private final FaultTolerantRedisClusterClient cluster;
private final Clock clock;
private final Counter limitExceededCounter;
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
private final Clock clock;
public DynamicRateLimiter(
final String name,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final Supplier<RateLimiterConfig> configResolver,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cluster,
final Clock clock) {
this.name = requireNonNull(name);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.configResolver = requireNonNull(configResolver);
this.validateScript = requireNonNull(validateScript);
this.cluster = requireNonNull(cluster);
this.clock = requireNonNull(clock);
this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
current().getRight().validate(key, amount);
final RateLimiterConfig config = config();
try {
final long deficitPermitsAmount = executeValidateScript(config, key, amount, true);
if (deficitPermitsAmount > 0) {
limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
throw new RateLimitExceededException(retryAfter);
}
} catch (final Exception e) {
if (e instanceof RateLimitExceededException rateLimitExceededException) {
throw rateLimitExceededException;
}
if (!config.failOpen()) {
throw e;
}
}
}
@Override
public CompletionStage<Void> validateAsync(final String key, final int amount) {
return current().getRight().validateAsync(key, amount);
final RateLimiterConfig config = config();
return executeValidateScriptAsync(config, key, amount, true)
.thenCompose(deficitPermitsAmount -> {
if (deficitPermitsAmount == 0) {
return CompletableFuture.completedFuture((Void) null);
}
limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
return CompletableFuture.failedFuture(new RateLimitExceededException(retryAfter));
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
throw ExceptionUtils.wrap(rateLimitExceededException);
}
if (config.failOpen()) {
return null;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return current().getRight().hasAvailablePermits(key, permits);
final RateLimiterConfig config = config();
try {
final long deficitPermitsAmount = executeValidateScript(config, key, permits, false);
return deficitPermitsAmount == 0;
} catch (final Exception e) {
if (config.failOpen()) {
return true;
} else {
throw e;
}
}
}
@Override
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
return current().getRight().hasAvailablePermitsAsync(key, amount);
final RateLimiterConfig config = config();
return executeValidateScriptAsync(config, key, amount, false)
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
.exceptionally(throwable -> {
if (config.failOpen()) {
return true;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public void clear(final String key) {
current().getRight().clear(key);
cluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
}
@Override
public CompletionStage<Void> clearAsync(final String key) {
return current().getRight().clearAsync(key);
return cluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
.thenRun(Util.NOOP);
}
@Override
public RateLimiterConfig config() {
return current().getLeft();
return configResolver.get();
}
private Pair<RateLimiterConfig, RateLimiter> current() {
final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock))
private long executeValidateScript(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return (Long) validateScript.execute(keys, arguments);
}
private CompletionStage<Long> executeValidateScriptAsync(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
}
private static String bucketName(final String name, final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}

View File

@@ -15,14 +15,9 @@ public interface RateLimiterDescriptor {
*/
String id();
/**
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
*/
boolean isDynamic();
/**
* @return an instance of {@link RateLimiterConfig} to be used by default,
* i.e. if there is no overrides in the application configuration files (static or dynamic).
* i.e. if there is no override in the application dynamic configuration.
*/
RateLimiterConfig defaultConfig();
}

View File

@@ -4,11 +4,9 @@
*/
package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
import java.time.Clock;
import java.time.Duration;
import java.util.Map;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
@@ -17,57 +15,54 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public enum For implements RateLimiterDescriptor {
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1), false)),
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)),
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20), true)),
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4), true)),
REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1), false)),
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7), false)),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)),
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), true)),
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), true)),
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)),
ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
STORIES("stories", new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)),
ALLOCATE_DEVICE("allocateDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
VERIFY_DEVICE("verifyDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
PROFILE("profile", new RateLimiterConfig(4320, Duration.ofSeconds(20), true)),
STICKER_PACK("stickerPack", new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
USERNAME_LOOKUP("usernameLookup", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
USERNAME_SET("usernameSet", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_RESERVE("usernameReserve", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_LINK_OPERATION("usernameLinkOperation", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", new RateLimiterConfig(1000, Duration.ofSeconds(4), true)),
REGISTRATION("registration", new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
VERIFICATION_CAPTCHA("verificationCaptcha", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
RATE_LIMIT_RESET("rateLimitReset", new RateLimiterConfig(2, Duration.ofHours(12), false)),
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
SET_BACKUP_ID("setBackupId", new RateLimiterConfig(10, Duration.ofHours(1), false)),
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", new RateLimiterConfig(5, Duration.ofDays(7), false)),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
GET_CALLING_RELAYS("getCallingRelays", new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
CREATE_CALL_LINK("createCallLink", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
INBOUND_MESSAGE_BYTES("inboundMessageBytes", new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)),
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
;
private final String id;
private final boolean dynamic;
private final RateLimiterConfig defaultConfig;
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
For(final String id, final RateLimiterConfig defaultConfig) {
this.id = id;
this.dynamic = dynamic;
this.defaultConfig = defaultConfig;
}
@@ -75,34 +70,25 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
return id;
}
@Override
public boolean isDynamic() {
return dynamic;
}
public RateLimiterConfig defaultConfig() {
return defaultConfig;
}
}
public static RateLimiters createAndValidate(
final Map<String, RateLimiterConfig> configs,
public static RateLimiters create(
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisClusterClient cacheCluster) {
final RateLimiters rateLimiters = new RateLimiters(
configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
rateLimiters.validateValuesAndConfigs();
return rateLimiters;
return new RateLimiters(
dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
}
@VisibleForTesting
RateLimiters(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock) {
super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
super(For.values(), dynamicConfigurationManager, validateScript, cacheCluster, clock);
}
public RateLimiter getAllocateDeviceLimiter() {

View File

@@ -1,170 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util;
public class StaticRateLimiter implements RateLimiter {
protected final String name;
private final RateLimiterConfig config;
private final Counter limitExceededCounter;
private final ClusterLuaScript validateScript;
private final FaultTolerantRedisClusterClient cacheCluster;
private final Clock clock;
public StaticRateLimiter(
final String name,
final RateLimiterConfig config,
final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock) {
this.name = requireNonNull(name);
this.config = requireNonNull(config);
this.validateScript = requireNonNull(validateScript);
this.cacheCluster = requireNonNull(cacheCluster);
this.clock = requireNonNull(clock);
this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
try {
final long deficitPermitsAmount = executeValidateScript(key, amount, true);
if (deficitPermitsAmount > 0) {
limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
throw new RateLimitExceededException(retryAfter);
}
} catch (final Exception e) {
if (e instanceof RateLimitExceededException rateLimitExceededException) {
throw rateLimitExceededException;
}
if (!config.failOpen()) {
throw e;
}
}
}
@Override
public CompletionStage<Void> validateAsync(final String key, final int amount) {
return executeValidateScriptAsync(key, amount, true)
.thenCompose(deficitPermitsAmount -> {
if (deficitPermitsAmount == 0) {
return completedFuture((Void) null);
}
limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
return failedFuture(new RateLimitExceededException(retryAfter));
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
throw ExceptionUtils.wrap(rateLimitExceededException);
}
if (config.failOpen()) {
return null;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public boolean hasAvailablePermits(final String key, final int amount) {
try {
final long deficitPermitsAmount = executeValidateScript(key, amount, false);
return deficitPermitsAmount == 0;
} catch (final Exception e) {
if (config.failOpen()) {
return true;
} else {
throw e;
}
}
}
@Override
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
return executeValidateScriptAsync(key, amount, false)
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
.exceptionally(throwable -> {
if (config.failOpen()) {
return true;
}
throw ExceptionUtils.wrap(throwable);
});
}
@Override
public void clear(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
}
@Override
public CompletionStage<Void> clearAsync(final String key) {
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
.thenRun(Util.NOOP);
}
@Override
public RateLimiterConfig config() {
return config;
}
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return (Long) validateScript.execute(keys, arguments);
}
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of(
String.valueOf(config.bucketSize()),
String.valueOf(config.leakRatePerMillis()),
String.valueOf(clock.millis()),
String.valueOf(amount),
String.valueOf(applyChanges)
);
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
}
@VisibleForTesting
protected static String bucketName(final String name, final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}