/v1/backup/auth/check endpoint added

This commit is contained in:
Sergey Skrobotov
2023-01-30 15:34:28 -08:00
parent 896e65545e
commit dc8f62a4ad
24 changed files with 1334 additions and 197 deletions

View File

@@ -0,0 +1,91 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
import javax.ws.rs.ClientErrorException;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
public class RateLimitByIpFilter implements ContainerRequestFilter {
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
@VisibleForTesting
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1));
private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();
private final RateLimiters rateLimiters;
public RateLimitByIpFilter(final RateLimiters rateLimiters) {
this.rateLimiters = requireNonNull(rateLimiters);
}
@Override
public void filter(final ContainerRequestContext requestContext) throws IOException {
// requestContext.getUriInfo() should always be an instance of `ExtendedUriInfo`
// in the Jersey client
if (!(requestContext.getUriInfo() instanceof final ExtendedUriInfo uriInfo)) {
return;
}
final RateLimitedByIp annotation = uriInfo.getMatchedResourceMethod()
.getInvocable()
.getHandlingMethod()
.getAnnotation(RateLimitedByIp.class);
if (annotation == null) {
return;
}
final RateLimiters.Handle handle = annotation.value();
try {
final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR);
final Optional<String> maybeMostRecentProxy = Optional.ofNullable(xffHeader)
.flatMap(HeaderUtils::getMostRecentProxy);
// checking if we failed to extract the most recent IP from the X-Forwarded-For header
// for any reason
if (maybeMostRecentProxy.isEmpty()) {
// checking if annotation is configured to fail when the most recent IP is not resolved
if (annotation.failOnUnresolvedIp()) {
logger.error("Missing/bad X-Forwarded-For: {}", xffHeader);
throw INVALID_HEADER_EXCEPTION;
}
// otherwise, allow request
return;
}
final Optional<RateLimiter> maybeRateLimiter = rateLimiters.byHandle(handle);
if (maybeRateLimiter.isEmpty()) {
logger.warn("RateLimiter not found for {}. Make sure it's initialized in RateLimiters class", handle);
return;
}
maybeRateLimiter.get().validate(maybeMostRecentProxy.get());
} catch (RateLimitExceededException e) {
final Response response = EXCEPTION_MAPPER.toResponse(e);
throw new ClientErrorException(response);
}
}
}

View File

@@ -0,0 +1,20 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimitedByIp {
RateLimiters.Handle value();
boolean failOnUnresolvedIp() default true;
}

View File

@@ -1,15 +1,41 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class RateLimiters {
public enum Handle {
USERNAME_LOOKUP("usernameLookup"),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence"),
BACKUP_AUTH_CHECK;
private final String id;
Handle(final String id) {
this.id = id;
}
Handle() {
this.id = name();
}
public String id() {
return id;
}
}
private final RateLimiter smsDestinationLimiter;
private final RateLimiter voiceDestinationLimiter;
private final RateLimiter voiceDestinationDailyLimiter;
@@ -17,114 +43,51 @@ public class RateLimiters {
private final RateLimiter smsVoicePrefixLimiter;
private final RateLimiter verifyLimiter;
private final RateLimiter pinLimiter;
private final RateLimiter attachmentLimiter;
private final RateLimiter preKeysLimiter;
private final RateLimiter messagesLimiter;
private final RateLimiter allocateDeviceLimiter;
private final RateLimiter verifyDeviceLimiter;
private final RateLimiter turnLimiter;
private final RateLimiter profileLimiter;
private final RateLimiter stickerPackLimiter;
private final RateLimiter artPackLimiter;
private final RateLimiter usernameLookupLimiter;
private final RateLimiter usernameSetLimiter;
private final RateLimiter usernameReserveLimiter;
private final RateLimiter checkAccountExistenceLimiter;
private final RateLimiter storiesLimiter;
public RateLimiters(RateLimitsConfiguration config, FaultTolerantRedisCluster cacheCluster) {
this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination",
config.getSmsDestination().getBucketSize(),
config.getSmsDestination().getLeakRatePerMinute());
private final Map<String, RateLimiter> rateLimiterByHandle;
this.voiceDestinationLimiter = new RateLimiter(cacheCluster, "voxDestination",
config.getVoiceDestination().getBucketSize(),
config.getVoiceDestination().getLeakRatePerMinute());
public RateLimiters(final RateLimitsConfiguration config, final FaultTolerantRedisCluster cacheCluster) {
this.smsDestinationLimiter = fromConfig("smsDestination", config.getSmsDestination(), cacheCluster);
this.voiceDestinationLimiter = fromConfig("voxDestination", config.getVoiceDestination(), cacheCluster);
this.voiceDestinationDailyLimiter = fromConfig("voxDestinationDaily", config.getVoiceDestinationDaily(), cacheCluster);
this.smsVoiceIpLimiter = fromConfig("smsVoiceIp", config.getSmsVoiceIp(), cacheCluster);
this.smsVoicePrefixLimiter = fromConfig("smsVoicePrefix", config.getSmsVoicePrefix(), cacheCluster);
this.verifyLimiter = fromConfig("verify", config.getVerifyNumber(), cacheCluster);
this.pinLimiter = fromConfig("pin", config.getVerifyPin(), cacheCluster);
this.attachmentLimiter = fromConfig("attachmentCreate", config.getAttachments(), cacheCluster);
this.preKeysLimiter = fromConfig("prekeys", config.getPreKeys(), cacheCluster);
this.messagesLimiter = fromConfig("messages", config.getMessages(), cacheCluster);
this.allocateDeviceLimiter = fromConfig("allocateDevice", config.getAllocateDevice(), cacheCluster);
this.verifyDeviceLimiter = fromConfig("verifyDevice", config.getVerifyDevice(), cacheCluster);
this.turnLimiter = fromConfig("turnAllocate", config.getTurnAllocations(), cacheCluster);
this.profileLimiter = fromConfig("profile", config.getProfile(), cacheCluster);
this.stickerPackLimiter = fromConfig("stickerPack", config.getStickerPack(), cacheCluster);
this.artPackLimiter = fromConfig("artPack", config.getArtPack(), cacheCluster);
this.usernameSetLimiter = fromConfig("usernameSet", config.getUsernameSet(), cacheCluster);
this.usernameReserveLimiter = fromConfig("usernameReserve", config.getUsernameReserve(), cacheCluster);
this.storiesLimiter = fromConfig("stories", config.getStories(), cacheCluster);
this.voiceDestinationDailyLimiter = new RateLimiter(cacheCluster, "voxDestinationDaily",
config.getVoiceDestinationDaily().getBucketSize(),
config.getVoiceDestinationDaily().getLeakRatePerMinute());
this.rateLimiterByHandle = Stream.of(
fromConfig(Handle.BACKUP_AUTH_CHECK.id(), config.getBackupAuthCheck(), cacheCluster),
fromConfig(Handle.CHECK_ACCOUNT_EXISTENCE.id(), config.getCheckAccountExistence(), cacheCluster),
fromConfig(Handle.USERNAME_LOOKUP.id(), config.getUsernameLookup(), cacheCluster)
).map(rl -> Pair.of(rl.name, rl)).collect(Collectors.toMap(Pair::getKey, Pair::getValue));
}
this.smsVoiceIpLimiter = new RateLimiter(cacheCluster, "smsVoiceIp",
config.getSmsVoiceIp().getBucketSize(),
config.getSmsVoiceIp().getLeakRatePerMinute());
this.smsVoicePrefixLimiter = new RateLimiter(cacheCluster, "smsVoicePrefix",
config.getSmsVoicePrefix().getBucketSize(),
config.getSmsVoicePrefix().getLeakRatePerMinute());
this.verifyLimiter = new LockingRateLimiter(cacheCluster, "verify",
config.getVerifyNumber().getBucketSize(),
config.getVerifyNumber().getLeakRatePerMinute());
this.pinLimiter = new LockingRateLimiter(cacheCluster, "pin",
config.getVerifyPin().getBucketSize(),
config.getVerifyPin().getLeakRatePerMinute());
this.attachmentLimiter = new RateLimiter(cacheCluster, "attachmentCreate",
config.getAttachments().getBucketSize(),
config.getAttachments().getLeakRatePerMinute());
this.preKeysLimiter = new RateLimiter(cacheCluster, "prekeys",
config.getPreKeys().getBucketSize(),
config.getPreKeys().getLeakRatePerMinute());
this.messagesLimiter = new RateLimiter(cacheCluster, "messages",
config.getMessages().getBucketSize(),
config.getMessages().getLeakRatePerMinute());
this.allocateDeviceLimiter = new RateLimiter(cacheCluster, "allocateDevice",
config.getAllocateDevice().getBucketSize(),
config.getAllocateDevice().getLeakRatePerMinute());
this.verifyDeviceLimiter = new RateLimiter(cacheCluster, "verifyDevice",
config.getVerifyDevice().getBucketSize(),
config.getVerifyDevice().getLeakRatePerMinute());
this.turnLimiter = new RateLimiter(cacheCluster, "turnAllocate",
config.getTurnAllocations().getBucketSize(),
config.getTurnAllocations().getLeakRatePerMinute());
this.profileLimiter = new RateLimiter(cacheCluster, "profile",
config.getProfile().getBucketSize(),
config.getProfile().getLeakRatePerMinute());
this.stickerPackLimiter = new RateLimiter(cacheCluster, "stickerPack",
config.getStickerPack().getBucketSize(),
config.getStickerPack().getLeakRatePerMinute());
this.artPackLimiter = new RateLimiter(cacheCluster, "artPack",
config.getArtPack().getBucketSize(),
config.getArtPack().getLeakRatePerMinute());
this.usernameLookupLimiter = new RateLimiter(cacheCluster, "usernameLookup",
config.getUsernameLookup().getBucketSize(),
config.getUsernameLookup().getLeakRatePerMinute());
this.usernameSetLimiter = new RateLimiter(cacheCluster, "usernameSet",
config.getUsernameSet().getBucketSize(),
config.getUsernameSet().getLeakRatePerMinute());
this.usernameReserveLimiter = new RateLimiter(cacheCluster, "usernameReserve",
config.getUsernameReserve().getBucketSize(),
config.getUsernameReserve().getLeakRatePerMinute());
this.checkAccountExistenceLimiter = new RateLimiter(cacheCluster, "checkAccountExistence",
config.getCheckAccountExistence().getBucketSize(),
config.getCheckAccountExistence().getLeakRatePerMinute());
this.storiesLimiter = new RateLimiter(cacheCluster, "stories",
config.getStories().getBucketSize(),
config.getStories().getLeakRatePerMinute());
public Optional<RateLimiter> byHandle(final Handle handle) {
return Optional.ofNullable(rateLimiterByHandle.get(handle.id()));
}
public RateLimiter getAllocateDeviceLimiter() {
@@ -192,7 +155,7 @@ public class RateLimiters {
}
public RateLimiter getUsernameLookupLimiter() {
return usernameLookupLimiter;
return byHandle(Handle.USERNAME_LOOKUP).orElseThrow();
}
public RateLimiter getUsernameSetLimiter() {
@@ -204,8 +167,17 @@ public class RateLimiters {
}
public RateLimiter getCheckAccountExistenceLimiter() {
return checkAccountExistenceLimiter;
return byHandle(Handle.CHECK_ACCOUNT_EXISTENCE).orElseThrow();
}
public RateLimiter getStoriesLimiter() { return storiesLimiter; }
public RateLimiter getStoriesLimiter() {
return storiesLimiter;
}
private static RateLimiter fromConfig(
final String name,
final RateLimitsConfiguration.RateLimitConfiguration cfg,
final FaultTolerantRedisCluster cacheCluster) {
return new RateLimiter(cacheCluster, name, cfg.getBucketSize(), cfg.getLeakRatePerMinute());
}
}