Add ExternalRequestFilter

This commit is contained in:
Chris Eager
2024-01-29 16:48:56 -06:00
committed by Chris Eager
parent 63c8b275d1
commit 5b97bc04e0
10 changed files with 581 additions and 1 deletions

View File

@@ -30,6 +30,7 @@ import org.whispersystems.textsecuregcm.configuration.DirectoryV2Configuration;
import org.whispersystems.textsecuregcm.configuration.DogstatsdConfiguration;
import org.whispersystems.textsecuregcm.configuration.DynamoDbClientConfiguration;
import org.whispersystems.textsecuregcm.configuration.DynamoDbTables;
import org.whispersystems.textsecuregcm.configuration.ExternalRequestFilterConfiguration;
import org.whispersystems.textsecuregcm.configuration.FcmConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration;
import org.whispersystems.textsecuregcm.configuration.GenericZkConfig;
@@ -339,6 +340,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private NoiseWebSocketTunnelConfiguration noiseTunnel;
@Valid
@NotNull
@JsonProperty
private ExternalRequestFilterConfiguration externalRequestFilter;
public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() {
return tlsKeyStore;
}
@@ -565,4 +571,8 @@ public class WhisperServerConfiguration extends Configuration {
public NoiseWebSocketTunnelConfiguration getNoiseWebSocketTunnelConfiguration() {
return noiseTunnel;
}
public ExternalRequestFilterConfiguration getExternalRequestFilterConfiguration() {
return externalRequestFilter;
}
}

View File

@@ -123,6 +123,7 @@ import org.whispersystems.textsecuregcm.currency.CoinMarketCapClient;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.currency.FixerClient;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.ExternalRequestFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
@@ -778,6 +779,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
// depends on the user-agent context so it has to come first here!
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
serverBuilder
.intercept(
new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(),
config.getExternalRequestFilterConfiguration().grpcMethods()))
// TODO: specialize metrics with user-agent platform
.intercept(metricCollectingServerInterceptor)
.intercept(errorMappingInterceptor)
@@ -827,6 +831,14 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
}
if (!config.getExternalRequestFilterConfiguration().paths().isEmpty()) {
environment.servlets().addFilter(ExternalRequestFilter.class.getSimpleName(),
new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(),
config.getExternalRequestFilterConfiguration().grpcMethods()))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true,
config.getExternalRequestFilterConfiguration().paths().toArray(new String[]{}));
}
final AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter =
new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>()
.setAuthenticator(accountAuthenticator)

View File

@@ -0,0 +1,16 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import java.util.Set;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.InetAddressRange;
public record ExternalRequestFilterConfiguration(@Valid @NotNull Set<@NotNull String> paths,
@Valid @NotNull Set<@NotNull InetAddressRange> permittedInternalRanges,
@Valid @NotNull Set<@NotNull String> grpcMethods) {
}

View File

@@ -0,0 +1,101 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.filters;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Set;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesUtil;
import org.whispersystems.textsecuregcm.util.InetAddressRange;
public class ExternalRequestFilter implements Filter, ServerInterceptor {
private static final Logger logger = LoggerFactory.getLogger(ExternalRequestFilter.class);
private static final String REQUESTS_COUNTER_NAME = name(ExternalRequestFilter.class, "requests");
private static final String PROTOCOL_TAG_NAME = "protocol";
private static final String BLOCKED_TAG_NAME = "blocked";
private final Set<InetAddressRange> permittedInternalAddressRanges;
private final Set<String> filteredGrpcMethodNames;
public ExternalRequestFilter(final Set<InetAddressRange> permittedInternalAddressRanges,
final Set<String> filteredGrpcMethodNames) {
this.permittedInternalAddressRanges = permittedInternalAddressRanges;
this.filteredGrpcMethodNames = filteredGrpcMethodNames;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers, final ServerCallHandler<ReqT, RespT> next) {
final MethodDescriptor<ReqT, RespT> methodDescriptor = call.getMethodDescriptor();
final boolean shouldFilterMethod = filteredGrpcMethodNames.contains(methodDescriptor.getFullMethodName());
final InetAddress remoteAddress = RequestAttributesUtil.getRemoteAddress();
final boolean blocked = shouldFilterMethod && shouldBlock(remoteAddress);
Metrics.counter(REQUESTS_COUNTER_NAME,
PROTOCOL_TAG_NAME, "grpc",
BLOCKED_TAG_NAME, String.valueOf(blocked))
.increment();
if (blocked) {
call.close(Status.NOT_FOUND, new Metadata());
return new ServerCall.Listener<>() {};
}
return next.startCall(call, headers);
}
@Override
public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
throws IOException, ServletException {
final InetAddress remoteInetAddress = InetAddress.getByName(
(String) request.getAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME));
final boolean restricted = shouldBlock(remoteInetAddress);
Metrics.counter(REQUESTS_COUNTER_NAME,
PROTOCOL_TAG_NAME, "http",
BLOCKED_TAG_NAME, String.valueOf(restricted))
.increment();
if (restricted) {
if (response instanceof HttpServletResponse hsr) {
hsr.setStatus(404);
} else {
logger.warn("response was an unexpected type: {}", response.getClass());
}
return;
}
chain.doFilter(request, response);
}
public boolean shouldBlock(InetAddress remoteAddress) {
return permittedInternalAddressRanges.stream()
.noneMatch(range -> range.contains(remoteAddress));
}
}

View File

@@ -0,0 +1,103 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.InetAddresses;
import java.net.InetAddress;
import java.util.Arrays;
/**
* An InetAddressRange represents a contiguous range of IPv4 or IPv6 addresses.
*/
public class InetAddressRange {
private final InetAddress networkAddress;
private final byte[] networkAddressBytes;
private final byte[] prefixMask;
public InetAddressRange(final String cidrBlock) {
final String[] components = cidrBlock.split("/");
if (components.length != 2) {
throw new IllegalArgumentException("Unexpected CIDR block notation: " + cidrBlock);
}
final int prefixLength;
try {
networkAddress = InetAddresses.forString(components[0]);
prefixLength = Integer.parseInt(components[1]);
if (prefixLength > networkAddress.getAddress().length * 8) {
throw new IllegalArgumentException("Prefix length cannot exceed length of address");
}
} catch (final NumberFormatException e) {
throw new IllegalArgumentException("Bad prefix length: " + components[1]);
}
networkAddressBytes = networkAddress.getAddress();
prefixMask = generatePrefixMask(networkAddressBytes.length, prefixLength);
}
@VisibleForTesting
static byte[] generatePrefixMask(final int addressLengthBytes, final int prefixLengthBits) {
final byte[] prefixMask = new byte[addressLengthBytes];
for (int i = 0; i < addressLengthBytes; i++) {
final int bitsAvailable = Math.min(8, Math.max(0, prefixLengthBits - (i * 8)));
prefixMask[i] = (byte) (0xff << (8 - bitsAvailable));
}
return prefixMask;
}
public boolean contains(final String name) {
// InetAddresses.forString() throws "IllegalArgumentException" for anything that is not an IP address
return contains(InetAddresses.forString(name));
}
public boolean contains(final InetAddress inetAddress) {
if (!networkAddress.getClass().equals(inetAddress.getClass())) {
return false;
}
final byte[] addressBytes = inetAddress.getAddress();
for (int i = 0; i < addressBytes.length; i++) {
if (((addressBytes[i] ^ networkAddressBytes[i]) & prefixMask[i]) != 0) {
return false;
}
}
return true;
}
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final InetAddressRange that = (InetAddressRange) o;
if (!networkAddress.equals(that.networkAddress)) {
return false;
}
return Arrays.equals(prefixMask, that.prefixMask);
}
@Override
public int hashCode() {
int result = networkAddress.hashCode();
result = 31 * result + Arrays.hashCode(prefixMask);
return result;
}
}