mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-22 00:18:06 +01:00
Refactor remote address/X-Forwarded-For handling
This commit is contained in:
@@ -69,6 +69,7 @@ import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
@@ -306,9 +307,9 @@ class AuthEnablementRefreshRequirementProviderTest {
|
||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
|
||||
provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
||||
requestLog, new TestPrincipal("test", account, authenticatedDevice), new ProtobufWebSocketMessageFactory(),
|
||||
Optional.empty(), Duration.ofMillis(30000));
|
||||
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
|
||||
applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
Session session = mock(Session.class);
|
||||
|
||||
@@ -91,6 +91,7 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
|
||||
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
@@ -119,9 +120,6 @@ class AccountControllerTest {
|
||||
private static final UUID SENDER_REG_LOCK_UUID = UUID.randomUUID();
|
||||
private static final UUID SENDER_TRANSFER_UUID = UUID.randomUUID();
|
||||
|
||||
private static final String NICE_HOST = "127.0.0.1";
|
||||
private static final String RATE_LIMITED_IP_HOST = "10.0.0.1";
|
||||
|
||||
private static AccountsManager accountsManager = mock(AccountsManager.class);
|
||||
private static RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||
private static RateLimiter rateLimiter = mock(RateLimiter.class);
|
||||
@@ -140,6 +138,9 @@ class AccountControllerTest {
|
||||
|
||||
private byte[] registration_lock_key = new byte[32];
|
||||
|
||||
private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER
|
||||
= new TestRemoteAddressFilterProvider("127.0.0.1");
|
||||
|
||||
private static final ResourceExtension resources = ResourceExtension.builder()
|
||||
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
||||
.addProvider(AuthHelper.getAuthFilter())
|
||||
@@ -148,7 +149,8 @@ class AccountControllerTest {
|
||||
.addProvider(new RateLimitExceededExceptionMapper())
|
||||
.addProvider(new ImpossiblePhoneNumberExceptionMapper())
|
||||
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
|
||||
.addProvider(new RateLimitByIpFilter(rateLimiters, true))
|
||||
.addProvider(TEST_REMOTE_ADDRESS_FILTER_PROVIDER)
|
||||
.addProvider(new RateLimitByIpFilter(rateLimiters))
|
||||
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
|
||||
.addProvider(SenderOverrideProvider.SenderOverrideFeature.class)
|
||||
.setMapper(SystemMapper.jsonMapper())
|
||||
|
||||
@@ -43,18 +43,16 @@ import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
||||
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
|
||||
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
|
||||
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
|
||||
import org.whispersystems.textsecuregcm.spam.SenderOverride;
|
||||
import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class ChallengeControllerTest {
|
||||
|
||||
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
|
||||
|
||||
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager,
|
||||
true);
|
||||
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
|
||||
|
||||
private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>();
|
||||
|
||||
@@ -73,6 +71,7 @@ class ChallengeControllerTest {
|
||||
return true;
|
||||
}
|
||||
})
|
||||
.addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
|
||||
.setMapper(SystemMapper.jsonMapper())
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addResource(new RateLimitExceededExceptionMapper())
|
||||
|
||||
@@ -118,7 +118,7 @@ class VerificationControllerTest {
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addResource(
|
||||
new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager,
|
||||
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, true,
|
||||
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager,
|
||||
dynamicConfigurationManager, clock))
|
||||
.build();
|
||||
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.filters;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||
|
||||
import com.google.common.net.HttpHeaders;
|
||||
import io.dropwizard.core.Application;
|
||||
import io.dropwizard.core.Configuration;
|
||||
import io.dropwizard.core.setup.Environment;
|
||||
import io.dropwizard.testing.junit5.DropwizardAppExtension;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.net.URI;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.Principal;
|
||||
import java.time.Duration;
|
||||
import java.util.Arrays;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.security.auth.Subject;
|
||||
import javax.servlet.DispatcherType;
|
||||
import javax.ws.rs.GET;
|
||||
import javax.ws.rs.Path;
|
||||
import javax.ws.rs.client.Client;
|
||||
import javax.ws.rs.container.ContainerRequestContext;
|
||||
import javax.ws.rs.core.Context;
|
||||
import org.eclipse.jetty.util.HostPort;
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.api.WebSocketListener;
|
||||
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
|
||||
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||
import org.whispersystems.websocket.messages.WebSocketMessage;
|
||||
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
|
||||
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
|
||||
import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class RemoteAddressFilterIntegrationTest {
|
||||
|
||||
private static final String WEBSOCKET_PREFIX = "/websocket";
|
||||
private static final String REMOTE_ADDRESS_PATH = "/remoteAddress";
|
||||
private static final String FORWARDED_FOR_PATH = "/forwardedFor";
|
||||
private static final String WS_REQUEST_PATH = "/wsRequest";
|
||||
|
||||
// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
|
||||
// in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a
|
||||
// full Jetty server in a separate process
|
||||
private static final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>(
|
||||
TestApplication.class);
|
||||
|
||||
@Nested
|
||||
class Rest {
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
|
||||
void testRemoteAddress(String ip) throws Exception {
|
||||
final Set<String> addresses = Arrays.stream(InetAddress.getAllByName("localhost"))
|
||||
.map(InetAddress::getHostAddress)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip));
|
||||
|
||||
Client client = EXTENSION.client();
|
||||
|
||||
final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
|
||||
String.format("http://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(), REMOTE_ADDRESS_PATH))
|
||||
.request("application/json")
|
||||
.get(RemoteAddressFilterIntegrationTest.TestResponse.class);
|
||||
|
||||
assertEquals(ip, response.remoteAddress());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
|
||||
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
|
||||
void testForwardedFor(String forwardedFor, String expectedIp) {
|
||||
|
||||
Client client = EXTENSION.client();
|
||||
|
||||
final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
|
||||
String.format("http://localhost:%d%s", EXTENSION.getLocalPort(), FORWARDED_FOR_PATH))
|
||||
.request("application/json")
|
||||
.header(HttpHeaders.X_FORWARDED_FOR, forwardedFor)
|
||||
.get(RemoteAddressFilterIntegrationTest.TestResponse.class);
|
||||
|
||||
assertEquals(expectedIp, response.remoteAddress());
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
class WebSocket {
|
||||
|
||||
private WebSocketClient client;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
client = new WebSocketClient();
|
||||
client.start();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
client.stop();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
|
||||
void testRemoteAddress(String ip) throws Exception {
|
||||
final Set<String> addresses = Arrays.stream(InetAddress.getAllByName("localhost"))
|
||||
.map(InetAddress::getHostAddress)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip));
|
||||
|
||||
final CompletableFuture<byte[]> responseFuture = new CompletableFuture<>();
|
||||
final ClientEndpoint clientEndpoint = new ClientEndpoint(WS_REQUEST_PATH, responseFuture);
|
||||
|
||||
client.connect(clientEndpoint,
|
||||
URI.create(
|
||||
String.format("ws://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(),
|
||||
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH)));
|
||||
|
||||
final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);
|
||||
|
||||
final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);
|
||||
|
||||
assertEquals(ip, response.remoteAddress());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
|
||||
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
|
||||
void testForwardedFor(String forwardedFor, String expectedIp) throws Exception {
|
||||
|
||||
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
|
||||
upgradeRequest.setHeader(HttpHeaders.X_FORWARDED_FOR, forwardedFor);
|
||||
|
||||
final CompletableFuture<byte[]> responseFuture = new CompletableFuture<>();
|
||||
|
||||
client.connect(new ClientEndpoint(WS_REQUEST_PATH, responseFuture),
|
||||
URI.create(
|
||||
String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), WEBSOCKET_PREFIX + FORWARDED_FOR_PATH)),
|
||||
upgradeRequest);
|
||||
|
||||
final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);
|
||||
|
||||
final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);
|
||||
|
||||
assertEquals(expectedIp, response.remoteAddress());
|
||||
}
|
||||
}
|
||||
|
||||
private static class ClientEndpoint implements WebSocketListener {
|
||||
|
||||
private final String requestPath;
|
||||
private final CompletableFuture<byte[]> responseFuture;
|
||||
private final WebSocketMessageFactory messageFactory;
|
||||
|
||||
ClientEndpoint(String requestPath, CompletableFuture<byte[]> responseFuture) {
|
||||
|
||||
this.requestPath = requestPath;
|
||||
this.responseFuture = responseFuture;
|
||||
this.messageFactory = new ProtobufWebSocketMessageFactory();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onWebSocketConnect(final Session session) {
|
||||
final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath,
|
||||
List.of("Accept: application/json"),
|
||||
Optional.empty()).toByteArray();
|
||||
try {
|
||||
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
|
||||
|
||||
try {
|
||||
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
|
||||
|
||||
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
|
||||
assert 200 == webSocketMessage.getResponseMessage().getStatus();
|
||||
responseFuture.complete(webSocketMessage.getResponseMessage().getBody().orElseThrow());
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
|
||||
}
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static abstract class TestController {
|
||||
|
||||
@GET
|
||||
public RemoteAddressFilterIntegrationTest.TestResponse get(@Context ContainerRequestContext context) {
|
||||
|
||||
return new RemoteAddressFilterIntegrationTest.TestResponse(
|
||||
(String) context.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME));
|
||||
}
|
||||
}
|
||||
|
||||
@Path(REMOTE_ADDRESS_PATH)
|
||||
public static class TestRemoteAddressController extends TestController {
|
||||
|
||||
}
|
||||
|
||||
@Path(FORWARDED_FOR_PATH)
|
||||
public static class TestForwardedForController extends TestController {
|
||||
|
||||
}
|
||||
|
||||
@Path(WS_REQUEST_PATH)
|
||||
public static class TestWebSocketController extends TestController {
|
||||
|
||||
}
|
||||
|
||||
public record TestResponse(String remoteAddress) {
|
||||
|
||||
}
|
||||
|
||||
public static class TestApplication extends Application<Configuration> {
|
||||
|
||||
@Override
|
||||
public void run(final Configuration configuration,
|
||||
final Environment environment) throws Exception {
|
||||
|
||||
// 2 filters, to cover useRemoteAddress = {true, false}
|
||||
// each has explicit (not wildcard) path matching
|
||||
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter(true))
|
||||
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
|
||||
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
|
||||
environment.servlets().addFilter("RemoteAddressFilterForwardedFor", new RemoteAddressFilter(false))
|
||||
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, FORWARDED_FOR_PATH,
|
||||
WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
|
||||
|
||||
environment.jersey().register(new TestRemoteAddressController());
|
||||
environment.jersey().register(new TestForwardedForController());
|
||||
|
||||
// WebSocket set up
|
||||
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
|
||||
|
||||
WebSocketEnvironment<TestPrincipal> webSocketEnvironment = new WebSocketEnvironment<>(environment,
|
||||
webSocketConfiguration, Duration.ofMillis(1000));
|
||||
|
||||
webSocketEnvironment.jersey().register(new TestWebSocketController());
|
||||
|
||||
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||
|
||||
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
|
||||
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||
|
||||
// 2 servlets, because the filter only runs for the Upgrade request
|
||||
environment.servlets().addServlet("WebSocketForwardedFor", webSocketServlet)
|
||||
.addMapping(WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
|
||||
environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
|
||||
.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A minimal {@code Principal} implementation, only used to satisfy constructors
|
||||
*/
|
||||
public static class TestPrincipal implements Principal {
|
||||
|
||||
// Principal implementation
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean implies(final Subject subject) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.filters;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.net.HttpHeaders;
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.ServletResponse;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
|
||||
class RemoteAddressFilterTest {
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource({
|
||||
"127.0.0.1, 127.0.0.1",
|
||||
"0:0:0:0:0:0:0:1, 0:0:0:0:0:0:0:1",
|
||||
"[0:0:0:0:0:0:0:1], 0:0:0:0:0:0:0:1"
|
||||
})
|
||||
void testGetRemoteAddress(final String remoteAddr, final String expectedRemoteAddr) throws Exception {
|
||||
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
|
||||
when(httpServletRequest.getRemoteAddr()).thenReturn(remoteAddr);
|
||||
|
||||
final RemoteAddressFilter filter = new RemoteAddressFilter(true);
|
||||
|
||||
final FilterChain filterChain = mock(FilterChain.class);
|
||||
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
|
||||
|
||||
verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
|
||||
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource(value = {
|
||||
"192.168.1.1, 127.0.0.1 \t 127.0.0.1",
|
||||
"192.168.1.1, 0:0:0:0:0:0:0:1 \t 0:0:0:0:0:0:0:1"
|
||||
}, delimiterString = "\t")
|
||||
void testGetRemoteAddressFromHeader(final String forwardedFor, final String expectedRemoteAddr) throws Exception {
|
||||
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
|
||||
when(httpServletRequest.getHeader(HttpHeaders.X_FORWARDED_FOR)).thenReturn(forwardedFor);
|
||||
|
||||
final RemoteAddressFilter filter = new RemoteAddressFilter(false);
|
||||
|
||||
final FilterChain filterChain = mock(FilterChain.class);
|
||||
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
|
||||
|
||||
verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
|
||||
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
public class RateLimitedByIpTest {
|
||||
@@ -60,7 +61,8 @@ public class RateLimitedByIpTest {
|
||||
.setMapper(SystemMapper.jsonMapper())
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addResource(new Controller())
|
||||
.addProvider(new RateLimitByIpFilter(RATE_LIMITERS, true))
|
||||
.addProvider(new RateLimitByIpFilter(RATE_LIMITERS))
|
||||
.addProvider(new TestRemoteAddressFilterProvider(IP))
|
||||
.build();
|
||||
|
||||
@Test
|
||||
|
||||
@@ -49,6 +49,7 @@ import org.glassfish.jersey.uri.UriTemplate;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
||||
@@ -138,12 +139,8 @@ class MetricsRequestEventListenerTest {
|
||||
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
applicationHandler,
|
||||
requestLog,
|
||||
new TestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(),
|
||||
Optional.empty(),
|
||||
Duration.ofMillis(30000));
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
final Session session = mock(Session.class);
|
||||
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
@@ -204,9 +201,8 @@ class MetricsRequestEventListenerTest {
|
||||
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
applicationHandler,
|
||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
||||
Duration.ofMillis(30000));
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
final Session session = mock(Session.class);
|
||||
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
|
||||
@@ -35,8 +35,7 @@ class HttpServletRequestUtilIntegrationTest {
|
||||
// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
|
||||
// in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a
|
||||
// full Jetty server in a separate process
|
||||
private final DropwizardAppExtension<TestConfiguration> EXTENSION = new DropwizardAppExtension<>(
|
||||
TestApplication.class);
|
||||
private final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>(TestApplication.class);
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
|
||||
@@ -72,13 +71,11 @@ class HttpServletRequestUtilIntegrationTest {
|
||||
|
||||
}
|
||||
|
||||
public static class TestApplication extends Application<TestConfiguration> {
|
||||
public static class TestApplication extends Application<Configuration> {
|
||||
|
||||
@Override
|
||||
public void run(final TestConfiguration configuration, final Environment environment) throws Exception {
|
||||
public void run(final Configuration configuration, final Environment environment) throws Exception {
|
||||
environment.jersey().register(new TestController());
|
||||
}
|
||||
}
|
||||
|
||||
public static class TestConfiguration extends Configuration {}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import javax.annotation.Priority;
|
||||
import javax.ws.rs.container.ContainerRequestContext;
|
||||
import javax.ws.rs.container.ContainerRequestFilter;
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Adds the request property set by {@link RemoteAddressFilter} for test scenarios that depend on it, but do not have
|
||||
* access to a full {@code HttpServletRequest} pipline
|
||||
*/
|
||||
@Priority(Integer.MIN_VALUE) // highest priority, since other filters might depend on it
|
||||
public class TestRemoteAddressFilterProvider implements ContainerRequestFilter {
|
||||
|
||||
private final String ip;
|
||||
|
||||
public TestRemoteAddressFilterProvider(String ip) {
|
||||
this.ip = ip;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void filter(final ContainerRequestContext requestContext) throws IOException {
|
||||
requestContext.setProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, ip);
|
||||
}
|
||||
}
|
||||
@@ -55,6 +55,7 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.slf4j.Logger;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||
@@ -173,9 +174,9 @@ class LoggingUnhandledExceptionMapperTest {
|
||||
|
||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
||||
Duration.ofMillis(30000));
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
doAnswer(answer -> {
|
||||
|
||||
Reference in New Issue
Block a user