Rate limiters code refactored

This commit is contained in:
Sergey Skrobotov
2023-02-23 10:21:39 -08:00
parent 378b32d44d
commit 7529c35013
35 changed files with 738 additions and 774 deletions

View File

@@ -0,0 +1,82 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import java.lang.invoke.MethodHandles;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private final Map<T, RateLimiter> rateLimiterByDescriptor;
private final Map<String, RateLimiterConfig> configs;
protected BaseRateLimiters(
final T[] values,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
this.configs = configs;
this.rateLimiterByDescriptor = Arrays.stream(values)
.map(descriptor -> Pair.of(
descriptor,
createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster)))
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
}
public RateLimiter forDescriptor(final T handle) {
return requireNonNull(rateLimiterByDescriptor.get(handle));
}
public void validateValuesAndConfigs() {
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
.map(RateLimiterDescriptor::id)
.collect(Collectors.toSet());
for (final String key: configs.keySet()) {
if (!ids.contains(key)) {
final String message = String.format(
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
key
);
logger.error(message);
throw new IllegalArgumentException(message);
}
}
}
private static RateLimiter createForDescriptor(
final RateLimiterDescriptor descriptor,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
if (descriptor.isDynamic()) {
final Supplier<RateLimiterConfig> configResolver = () -> {
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
return config != null
? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
};
return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster);
}
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster);
}
}

View File

@@ -0,0 +1,63 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class DynamicRateLimiter implements RateLimiter {
private final String name;
private final Supplier<RateLimiterConfig> configResolver;
private final FaultTolerantRedisCluster cluster;
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
public DynamicRateLimiter(
final String name,
final Supplier<RateLimiterConfig> configResolver,
final FaultTolerantRedisCluster cluster) {
this.name = requireNonNull(name);
this.configResolver = requireNonNull(configResolver);
this.cluster = requireNonNull(cluster);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
current().getRight().validate(key, amount);
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return current().getRight().hasAvailablePermits(key, permits);
}
@Override
public void clear(final String key) {
current().getRight().clear(key);
}
@Override
public RateLimiterConfig config() {
return current().getLeft();
}
private Pair<RateLimiterConfig, RateLimiter> current() {
final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster))
);
}
}

View File

@@ -1,130 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class DynamicRateLimiters {
private final FaultTolerantRedisCluster cacheCluster;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final AtomicReference<RateLimiter> rateLimitResetLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeSuccessLimiter;
private final AtomicReference<RateLimiter> pushChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> pushChallengeSuccessLimiter;
public DynamicRateLimiters(final FaultTolerantRedisCluster rateLimitCluster,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.cacheCluster = rateLimitCluster;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.rateLimitResetLimiter = new AtomicReference<>(
createRateLimitResetLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRateLimitReset()));
this.recaptchaChallengeAttemptLimiter = new AtomicReference<>(createRecaptchaChallengeAttemptLimiter(
this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeAttempt()));
this.recaptchaChallengeSuccessLimiter = new AtomicReference<>(createRecaptchaChallengeSuccessLimiter(
this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeSuccess()));
this.pushChallengeAttemptLimiter = new AtomicReference<>(createPushChallengeAttemptLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeAttempt()));
this.pushChallengeSuccessLimiter = new AtomicReference<>(createPushChallengeSuccessLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeSuccess()));
}
public RateLimiter getRateLimitResetLimiter() {
return updateAndGetRateLimiter(
rateLimitResetLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRateLimitReset(),
this::createRateLimitResetLimiter);
}
public RateLimiter getRecaptchaChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeAttemptLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeAttempt(),
this::createRecaptchaChallengeAttemptLimiter);
}
public RateLimiter getRecaptchaChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeSuccessLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeSuccess(),
this::createRecaptchaChallengeSuccessLimiter);
}
public RateLimiter getPushChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
pushChallengeAttemptLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeAttempt(),
this::createPushChallengeAttemptLimiter);
}
public RateLimiter getPushChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
pushChallengeSuccessLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeSuccess(),
this::createPushChallengeSuccessLimiter);
}
private RateLimiter updateAndGetRateLimiter(final AtomicReference<RateLimiter> rateLimiter,
RateLimitConfiguration currentConfiguration,
BiFunction<FaultTolerantRedisCluster, RateLimitConfiguration, RateLimiter> rateLimitFactory) {
return rateLimiter.updateAndGet(limiter -> {
if (limiter.hasConfiguration(currentConfiguration)) {
return limiter;
} else {
return rateLimitFactory.apply(cacheCluster, currentConfiguration);
}
});
}
public RateLimiter createRateLimitResetLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "rateLimitReset");
}
public RateLimiter createRecaptchaChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeAttempt");
}
public RateLimiter createRecaptchaChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeSuccess");
}
public RateLimiter createPushChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeAttempt");
}
public RateLimiter createPushChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeSuccess");
}
private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration,
String name) {
return new RateLimiter(cacheCluster, name,
configuration.getBucketSize(),
configuration.getLeakRatePerMinute());
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@@ -16,22 +16,28 @@ import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
public class LockingRateLimiter extends RateLimiter {
public class LockingRateLimiter extends StaticRateLimiter {
private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION
= new RateLimitExceededException(Duration.ZERO, true);
private final Meter meter;
public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
super(cacheCluster, name, bucketSize, leakRatePerMinute);
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
public LockingRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
super(name, config, cacheCluster);
final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
}
@Override
public void validate(String key, int amount) throws RateLimitExceededException {
public void validate(final String key, final int amount) throws RateLimitExceededException {
if (!acquireLock(key)) {
meter.mark();
throw new RateLimitExceededException(Duration.ZERO, true);
throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION;
}
try {
@@ -41,22 +47,15 @@ public class LockingRateLimiter extends RateLimiter {
}
}
@Override
public void validate(String key) throws RateLimitExceededException {
validate(key, 1);
}
private void releaseLock(String key) {
private void releaseLock(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key)));
}
private boolean acquireLock(String key) {
private boolean acquireLock(final String key) {
return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null;
}
private String getLockName(String key) {
private String getLockName(final String key) {
return "leaky_lock::" + name + "::" + key;
}
}

View File

@@ -58,7 +58,7 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
return;
}
final RateLimiters.Handle handle = annotation.value();
final RateLimiters.For handle = annotation.value();
try {
final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR);
@@ -77,13 +77,8 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
return;
}
final Optional<RateLimiter> maybeRateLimiter = rateLimiters.byHandle(handle);
if (maybeRateLimiter.isEmpty()) {
logger.warn("RateLimiter not found for {}. Make sure it's initialized in RateLimiters class", handle);
return;
}
maybeRateLimiter.get().validate(maybeMostRecentProxy.get());
final RateLimiter rateLimiter = rateLimiters.forDescriptor(handle);
rateLimiter.validate(maybeMostRecentProxy.get());
} catch (RateLimitExceededException e) {
final Response response = EXCEPTION_MAPPER.toResponse(e);
throw new ClientErrorException(response);

View File

@@ -1,3 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
@@ -9,11 +14,11 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util;
@@ -21,7 +26,7 @@ public class RateLimitChallengeManager {
private final PushChallengeManager pushChallengeManager;
private final CaptchaChecker captchaChecker;
private final DynamicRateLimiters rateLimiters;
private final RateLimiters rateLimiters;
private final List<RateLimitChallengeListener> rateLimitChallengeListeners =
Collections.synchronizedList(new ArrayList<>());
@@ -35,7 +40,7 @@ public class RateLimitChallengeManager {
public RateLimitChallengeManager(
final PushChallengeManager pushChallengeManager,
final CaptchaChecker captchaChecker,
final DynamicRateLimiters rateLimiters) {
final RateLimiters rateLimiters) {
this.pushChallengeManager = pushChallengeManager;
this.captchaChecker = captchaChecker;

View File

@@ -1,30 +1,30 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import com.vdurmont.semver4j.Semver;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class RateLimitChallengeOptionManager {
private final DynamicRateLimiters rateLimiters;
private final RateLimiters rateLimiters;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
public static final String OPTION_RECAPTCHA = "recaptcha";
public static final String OPTION_PUSH_CHALLENGE = "pushChallenge";
public RateLimitChallengeOptionManager(final DynamicRateLimiters rateLimiters,
public RateLimitChallengeOptionManager(final RateLimiters rateLimiters,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.rateLimiters = rateLimiters;

View File

@@ -14,7 +14,7 @@ import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimitedByIp {
RateLimiters.Handle value();
RateLimiters.For value();
boolean failOnUnresolvedIp() default true;
}

View File

@@ -1,143 +1,48 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class RateLimiter {
public interface RateLimiter {
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
private final ObjectMapper mapper = SystemMapper.getMapper();
void validate(String key, int amount) throws RateLimitExceededException;
private final Meter meter;
private final Timer validateTimer;
protected final FaultTolerantRedisCluster cacheCluster;
protected final String name;
private final int bucketSize;
private final double leakRatePerMinute;
private final double leakRatePerMillis;
boolean hasAvailablePermits(String key, int permits);
public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute)
{
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
void clear(String key);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate"));
this.cacheCluster = cacheCluster;
this.name = name;
this.bucketSize = bucketSize;
this.leakRatePerMinute = leakRatePerMinute;
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
}
RateLimiterConfig config();
public void validate(String key, int amount) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
meter.mark();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
}
}
}
public void validate(final UUID accountUuid) throws RateLimitExceededException {
validate(accountUuid.toString());
}
public void validate(final UUID sourceAccountUuid, final UUID destinationAccountUuid)
throws RateLimitExceededException {
validate(sourceAccountUuid.toString() + "__" + destinationAccountUuid.toString());
}
public void validate(String key) throws RateLimitExceededException {
default void validate(final String key) throws RateLimitExceededException {
validate(key, 1);
}
public boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
default void validate(final UUID accountUuid) throws RateLimitExceededException {
validate(accountUuid.toString());
}
default void validate(final UUID srcAccountUuid, final UUID dstAccountUuid) throws RateLimitExceededException {
validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
}
default boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
return hasAvailablePermits(accountUuid.toString(), permits);
}
public boolean hasAvailablePermits(final String key, final int permits) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
}
public void clear(final UUID accountUuid) {
default void clear(final UUID accountUuid) {
clear(accountUuid.toString());
}
public void clear(String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}
public int getBucketSize() {
return bucketSize;
}
public double getLeakRatePerMinute() {
return leakRatePerMinute;
}
private void setBucket(String key, LeakyBucket bucket) {
try {
final String serialized = bucket.serialize(mapper);
cacheCluster.useCluster(connection -> connection.sync().setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized));
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private LeakyBucket getBucket(String key) {
try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(mapper, serialized);
}
} catch (IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(bucketSize, leakRatePerMillis);
}
private String getBucketName(String key) {
return "leaky_bucket::" + name + "::" + key;
}
public boolean hasConfiguration(final RateLimitConfiguration configuration) {
return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute();
}
/**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code true}
*/
public static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException {
static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException {
try {
validator.validate();
} catch (final RateLimitExceededException e) {
@@ -146,9 +51,8 @@ public class RateLimiter {
}
@FunctionalInterface
public interface RateLimitValidator {
interface RateLimitValidator {
void validate() throws RateLimitExceededException;
}
}

View File

@@ -0,0 +1,13 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
public record RateLimiterConfig(int bucketSize, double leakRatePerMinute) {
public double leakRatePerMillis() {
return leakRatePerMinute / (60.0 * 1000.0);
}
}

View File

@@ -0,0 +1,28 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
/**
* Represents an information that defines a rate limiter.
*/
public interface RateLimiterDescriptor {
/**
* Implementing classes will likely be Enums, so name is chosen not to clash with {@link Enum#name()}.
* @return id of this rate limiter to be used in `yml` config files and as a part of the bucket key.
*/
String id();
/**
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
*/
boolean isDynamic();
/**
* @return an instance of {@link RateLimiterConfig} to be used by default,
* i.e. if there is no overrides in the application configuration files (static or dynamic).
*/
RateLimiterConfig defaultConfig();
}

View File

@@ -5,193 +5,232 @@
package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class RateLimiters {
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public enum Handle {
USERNAME_LOOKUP("usernameLookup"),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence"),
BACKUP_AUTH_CHECK;
public enum For implements RateLimiterDescriptor {
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
SMS_DESTINATION("smsDestination", false, new RateLimiterConfig(2, 2)),
VOICE_DESTINATION("voxDestination", false, new RateLimiterConfig(2, 1.0 / 2.0)),
VOICE_DESTINATION_DAILY("voxDestinationDaily", false, new RateLimiterConfig(10, 10.0 / (24.0 * 60.0))),
SMS_VOICE_IP("smsVoiceIp", false, new RateLimiterConfig(1000, 1000)),
SMS_VOICE_PREFIX("smsVoicePrefix", false, new RateLimiterConfig(1000, 1000)),
VERIFY("verify", false, new RateLimiterConfig(2, 2)),
PIN("pin", false, new RateLimiterConfig(10, 1 / (24.0 * 60.0))),
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, 50)),
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, 1.0 / 10.0)),
MESSAGES("messages", false, new RateLimiterConfig(60, 60)),
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(2, 1.0 / 2.0)),
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, 1.0 / 10.0)),
TURN("turnAllocate", false, new RateLimiterConfig(60, 60)),
PROFILE("profile", false, new RateLimiterConfig(4320, 3)),
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, 20 / (24.0 * 60.0))),
ART_PACK("artPack", false, new RateLimiterConfig(50, 20 / (24.0 * 60.0))),
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1_000, 1_000 / 60.0)),
STORIES("stories", false, new RateLimiterConfig(10_000, 10_000 / (24.0 * 60.0))),
REGISTRATION("registration", false, new RateLimiterConfig(2, 2)),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, 2)),
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, 2)),
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
RECAPTCHA_CHALLENGE_ATTEMPT("recaptchaChallengeAttempt", true, new RateLimiterConfig(10, 10.0 / (60 * 24))),
RECAPTCHA_CHALLENGE_SUCCESS("recaptchaChallengeSuccess", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, 10.0 / (60 * 24))),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
;
private final String id;
private final boolean dynamic;
Handle(final String id) {
private final RateLimiterConfig defaultConfig;
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
this.id = id;
}
Handle() {
this.id = name();
this.dynamic = dynamic;
this.defaultConfig = defaultConfig;
}
public String id() {
return id;
}
@Override
public boolean isDynamic() {
return dynamic;
}
public RateLimiterConfig defaultConfig() {
return defaultConfig;
}
}
private final RateLimiter smsDestinationLimiter;
private final RateLimiter voiceDestinationLimiter;
private final RateLimiter voiceDestinationDailyLimiter;
private final RateLimiter smsVoiceIpLimiter;
private final RateLimiter smsVoicePrefixLimiter;
private final RateLimiter verifyLimiter;
private final RateLimiter verificationCaptchaLimiter;
private final RateLimiter verificationPushChallengeLimiter;
private final RateLimiter pinLimiter;
private final RateLimiter registrationLimiter;
private final RateLimiter attachmentLimiter;
private final RateLimiter preKeysLimiter;
private final RateLimiter messagesLimiter;
private final RateLimiter allocateDeviceLimiter;
private final RateLimiter verifyDeviceLimiter;
private final RateLimiter turnLimiter;
private final RateLimiter profileLimiter;
private final RateLimiter stickerPackLimiter;
private final RateLimiter artPackLimiter;
private final RateLimiter usernameSetLimiter;
private final RateLimiter usernameReserveLimiter;
private final Map<String, RateLimiter> rateLimiterByHandle;
public RateLimiters(final RateLimitsConfiguration config, final FaultTolerantRedisCluster cacheCluster) {
this.smsDestinationLimiter = fromConfig("smsDestination", config.getSmsDestination(), cacheCluster);
this.voiceDestinationLimiter = fromConfig("voxDestination", config.getVoiceDestination(), cacheCluster);
this.voiceDestinationDailyLimiter = fromConfig("voxDestinationDaily", config.getVoiceDestinationDaily(),
cacheCluster);
this.smsVoiceIpLimiter = fromConfig("smsVoiceIp", config.getSmsVoiceIp(), cacheCluster);
this.smsVoicePrefixLimiter = fromConfig("smsVoicePrefix", config.getSmsVoicePrefix(), cacheCluster);
this.verifyLimiter = fromConfig("verify", config.getVerifyNumber(), cacheCluster);
this.verificationCaptchaLimiter = fromConfig("verificationCaptcha", config.getVerificationCaptcha(), cacheCluster);
this.verificationPushChallengeLimiter = fromConfig("verificationPushChallenge",
config.getVerificationPushChallenge(), cacheCluster);
this.pinLimiter = fromConfig("pin", config.getVerifyPin(), cacheCluster);
this.registrationLimiter = fromConfig("registration", config.getRegistration(), cacheCluster);
this.attachmentLimiter = fromConfig("attachmentCreate", config.getAttachments(), cacheCluster);
this.preKeysLimiter = fromConfig("prekeys", config.getPreKeys(), cacheCluster);
this.messagesLimiter = fromConfig("messages", config.getMessages(), cacheCluster);
this.allocateDeviceLimiter = fromConfig("allocateDevice", config.getAllocateDevice(), cacheCluster);
this.verifyDeviceLimiter = fromConfig("verifyDevice", config.getVerifyDevice(), cacheCluster);
this.turnLimiter = fromConfig("turnAllocate", config.getTurnAllocations(), cacheCluster);
this.profileLimiter = fromConfig("profile", config.getProfile(), cacheCluster);
this.stickerPackLimiter = fromConfig("stickerPack", config.getStickerPack(), cacheCluster);
this.artPackLimiter = fromConfig("artPack", config.getArtPack(), cacheCluster);
this.usernameSetLimiter = fromConfig("usernameSet", config.getUsernameSet(), cacheCluster);
this.usernameReserveLimiter = fromConfig("usernameReserve", config.getUsernameReserve(), cacheCluster);
this.rateLimiterByHandle = Stream.of(
fromConfig(Handle.BACKUP_AUTH_CHECK.id(), config.getBackupAuthCheck(), cacheCluster),
fromConfig(Handle.CHECK_ACCOUNT_EXISTENCE.id(), config.getCheckAccountExistence(), cacheCluster),
fromConfig(Handle.USERNAME_LOOKUP.id(), config.getUsernameLookup(), cacheCluster)
).map(rl -> Pair.of(rl.name, rl)).collect(Collectors.toMap(Pair::getKey, Pair::getValue));
public static RateLimiters createAndValidate(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster);
rateLimiters.validateValuesAndConfigs();
return rateLimiters;
}
public Optional<RateLimiter> byHandle(final Handle handle) {
return Optional.ofNullable(rateLimiterByHandle.get(handle.id()));
@VisibleForTesting
RateLimiters(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
super(For.values(), configs, dynamicConfigurationManager, cacheCluster);
}
public RateLimiter getAllocateDeviceLimiter() {
return allocateDeviceLimiter;
return forDescriptor(For.ALLOCATE_DEVICE);
}
public RateLimiter getVerifyDeviceLimiter() {
return verifyDeviceLimiter;
return forDescriptor(For.VERIFY_DEVICE);
}
public RateLimiter getMessagesLimiter() {
return messagesLimiter;
return forDescriptor(For.MESSAGES);
}
public RateLimiter getPreKeysLimiter() {
return preKeysLimiter;
return forDescriptor(For.PRE_KEYS);
}
public RateLimiter getAttachmentLimiter() {
return this.attachmentLimiter;
return forDescriptor(For.ATTACHMENT);
}
public RateLimiter getSmsDestinationLimiter() {
return smsDestinationLimiter;
return forDescriptor(For.SMS_DESTINATION);
}
public RateLimiter getSmsVoiceIpLimiter() {
return smsVoiceIpLimiter;
return forDescriptor(For.SMS_VOICE_IP);
}
public RateLimiter getSmsVoicePrefixLimiter() {
return smsVoicePrefixLimiter;
return forDescriptor(For.SMS_VOICE_PREFIX);
}
public RateLimiter getVoiceDestinationLimiter() {
return voiceDestinationLimiter;
return forDescriptor(For.VOICE_DESTINATION);
}
public RateLimiter getVoiceDestinationDailyLimiter() {
return voiceDestinationDailyLimiter;
return forDescriptor(For.VOICE_DESTINATION_DAILY);
}
public RateLimiter getVerifyLimiter() {
return verifyLimiter;
}
public RateLimiter getVerificationCaptchaLimiter() {
return verificationCaptchaLimiter;
}
public RateLimiter getVerificationPushChallengeLimiter() {
return verificationPushChallengeLimiter;
return forDescriptor(For.VERIFY);
}
public RateLimiter getPinLimiter() {
return pinLimiter;
}
public RateLimiter getRegistrationLimiter() {
return registrationLimiter;
return forDescriptor(For.PIN);
}
public RateLimiter getTurnLimiter() {
return turnLimiter;
return forDescriptor(For.TURN);
}
public RateLimiter getProfileLimiter() {
return profileLimiter;
return forDescriptor(For.PROFILE);
}
public RateLimiter getStickerPackLimiter() {
return stickerPackLimiter;
return forDescriptor(For.STICKER_PACK);
}
public RateLimiter getArtPackLimiter() {
return artPackLimiter;
return forDescriptor(For.ART_PACK);
}
public RateLimiter getUsernameLookupLimiter() {
return byHandle(Handle.USERNAME_LOOKUP).orElseThrow();
return forDescriptor(For.USERNAME_LOOKUP);
}
public RateLimiter getUsernameSetLimiter() {
return usernameSetLimiter;
return forDescriptor(For.USERNAME_SET);
}
public RateLimiter getUsernameReserveLimiter() {
return usernameReserveLimiter;
return forDescriptor(For.USERNAME_RESERVE);
}
public RateLimiter getCheckAccountExistenceLimiter() {
return byHandle(Handle.CHECK_ACCOUNT_EXISTENCE).orElseThrow();
return forDescriptor(For.CHECK_ACCOUNT_EXISTENCE);
}
private static RateLimiter fromConfig(
final String name,
final RateLimitsConfiguration.RateLimitConfiguration cfg,
final FaultTolerantRedisCluster cacheCluster) {
return new RateLimiter(cacheCluster, name, cfg.getBucketSize(), cfg.getLeakRatePerMinute());
public RateLimiter getStoriesLimiter() {
return forDescriptor(For.STORIES);
}
public RateLimiter getRegistrationLimiter() {
return forDescriptor(For.REGISTRATION);
}
public RateLimiter getRateLimitResetLimiter() {
return forDescriptor(For.RATE_LIMIT_RESET);
}
public RateLimiter getRecaptchaChallengeAttemptLimiter() {
return forDescriptor(For.RECAPTCHA_CHALLENGE_ATTEMPT);
}
public RateLimiter getRecaptchaChallengeSuccessLimiter() {
return forDescriptor(For.RECAPTCHA_CHALLENGE_SUCCESS);
}
public RateLimiter getPushChallengeAttemptLimiter() {
return forDescriptor(For.PUSH_CHALLENGE_ATTEMPT);
}
public RateLimiter getPushChallengeSuccessLimiter() {
return forDescriptor(For.PUSH_CHALLENGE_SUCCESS);
}
public RateLimiter getVerificationPushChallengeLimiter() {
return forDescriptor(For.VERIFICATION_PUSH_CHALLENGE);
}
public RateLimiter getVerificationCaptchaLimiter() {
return forDescriptor(For.VERIFICATION_CAPTCHA);
}
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import static java.util.Objects.requireNonNull;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class StaticRateLimiter implements RateLimiter {
private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class);
private static final ObjectMapper MAPPER = SystemMapper.getMapper();
protected final String name;
private final RateLimiterConfig config;
protected final FaultTolerantRedisCluster cacheCluster;
private final Meter meter;
private final Timer validateTimer;
public StaticRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.name = requireNonNull(name);
this.config = requireNonNull(config);
this.cacheCluster = requireNonNull(cacheCluster);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate"));
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
final LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
meter.mark();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
}
}
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
}
@Override
public void clear(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}
@Override
public RateLimiterConfig config() {
return config;
}
private void setBucket(final String key, final LeakyBucket bucket) {
try {
final String serialized = bucket.serialize(MAPPER);
cacheCluster.useCluster(connection -> connection.sync().setex(
getBucketName(key),
(int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000),
serialized));
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private LeakyBucket getBucket(final String key) {
try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(MAPPER, serialized);
}
} catch (final IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis());
}
private String getBucketName(final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}