Remove X-Forwarded-For from RemoteAddressFilter

This commit is contained in:
Chris Eager
2024-04-10 17:40:55 -05:00
committed by Chris Eager
parent 39fd955f13
commit 05a92494bb
10 changed files with 59 additions and 161 deletions

View File

@@ -3,8 +3,6 @@ package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.core.Application;
@@ -77,10 +75,10 @@ public class WebsocketResourceProviderIntegrationTest {
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest ->
ReusableAuth.authenticated(mock(AuthenticatedAccount.class), PrincipalSupplier.forImmutablePrincipal()));

View File

@@ -95,10 +95,10 @@ public class WebsocketReuseAuthIntegrationTest {
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);

View File

@@ -118,9 +118,9 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
environment.jersey()

View File

@@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import com.google.common.net.HttpHeaders;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
@@ -39,7 +38,6 @@ import javax.ws.rs.core.Context;
import org.eclipse.jetty.util.HostPort;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
@@ -47,7 +45,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@@ -62,7 +59,6 @@ class RemoteAddressFilterIntegrationTest {
private static final String WEBSOCKET_PREFIX = "/websocket";
private static final String REMOTE_ADDRESS_PATH = "/remoteAddress";
private static final String FORWARDED_FOR_PATH = "/forwardedFor";
private static final String WS_REQUEST_PATH = "/wsRequest";
// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
@@ -92,22 +88,6 @@ class RemoteAddressFilterIntegrationTest {
assertEquals(ip, response.remoteAddress());
}
@ParameterizedTest
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
void testForwardedFor(String forwardedFor, String expectedIp) {
Client client = EXTENSION.client();
final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
String.format("http://localhost:%d%s", EXTENSION.getLocalPort(), FORWARDED_FOR_PATH))
.request("application/json")
.header(HttpHeaders.X_FORWARDED_FOR, forwardedFor)
.get(RemoteAddressFilterIntegrationTest.TestResponse.class);
assertEquals(expectedIp, response.remoteAddress());
}
}
@Nested
@@ -149,28 +129,6 @@ class RemoteAddressFilterIntegrationTest {
assertEquals(ip, response.remoteAddress());
}
@ParameterizedTest
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
void testForwardedFor(String forwardedFor, String expectedIp) throws Exception {
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(HttpHeaders.X_FORWARDED_FOR, forwardedFor);
final CompletableFuture<byte[]> responseFuture = new CompletableFuture<>();
client.connect(new ClientEndpoint(WS_REQUEST_PATH, responseFuture),
URI.create(
String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), WEBSOCKET_PREFIX + FORWARDED_FOR_PATH)),
upgradeRequest);
final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);
final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);
assertEquals(expectedIp, response.remoteAddress());
}
}
private static class ClientEndpoint implements WebSocketListener {
@@ -233,11 +191,6 @@ class RemoteAddressFilterIntegrationTest {
}
@Path(FORWARDED_FOR_PATH)
public static class TestForwardedForController extends TestController {
}
@Path(WS_REQUEST_PATH)
public static class TestWebSocketController extends TestController {
@@ -253,17 +206,11 @@ class RemoteAddressFilterIntegrationTest {
public void run(final Configuration configuration,
final Environment environment) throws Exception {
// 2 filters, to cover useRemoteAddress = {true, false}
// each has explicit (not wildcard) path matching
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter(true))
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
environment.servlets().addFilter("RemoteAddressFilterForwardedFor", new RemoteAddressFilter(false))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, FORWARDED_FOR_PATH,
WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
environment.jersey().register(new TestRemoteAddressController());
environment.jersey().register(new TestForwardedForController());
// WebSocket set up
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
@@ -279,9 +226,6 @@ class RemoteAddressFilterIntegrationTest {
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
// 2 servlets, because the filter only runs for the Upgrade request
environment.servlets().addServlet("WebSocketForwardedFor", webSocketServlet)
.addMapping(WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);

View File

@@ -5,24 +5,17 @@
package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import java.util.Optional;
import java.util.stream.Stream;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
class RemoteAddressFilterTest {
@@ -36,7 +29,7 @@ class RemoteAddressFilterTest {
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getRemoteAddr()).thenReturn(remoteAddr);
final RemoteAddressFilter filter = new RemoteAddressFilter(true);
final RemoteAddressFilter filter = new RemoteAddressFilter();
final FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
@@ -45,41 +38,4 @@ class RemoteAddressFilterTest {
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
}
@ParameterizedTest
@CsvSource(value = {
"192.168.1.1, 127.0.0.1 \t 127.0.0.1",
"192.168.1.1, 0:0:0:0:0:0:0:1 \t 0:0:0:0:0:0:0:1"
}, delimiterString = "\t")
void testGetRemoteAddressFromHeader(final String forwardedFor, final String expectedRemoteAddr) throws Exception {
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getHeader(HttpHeaders.X_FORWARDED_FOR)).thenReturn(forwardedFor);
final RemoteAddressFilter filter = new RemoteAddressFilter(false);
final FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ParameterizedTest
@MethodSource("argumentsForGetMostRecentProxy")
void getMostRecentProxy(final String forwardedFor, final Optional<String> expectedMostRecentProxy) {
assertEquals(expectedMostRecentProxy, RemoteAddressFilter.getMostRecentProxy(forwardedFor));
}
private static Stream<Arguments> argumentsForGetMostRecentProxy() {
return Stream.of(
arguments(null, Optional.empty()),
arguments("", Optional.empty()),
arguments(" ", Optional.empty()),
arguments("203.0.113.195,", Optional.empty()),
arguments("203.0.113.195, ", Optional.empty()),
arguments("203.0.113.195", Optional.of("203.0.113.195")),
arguments("203.0.113.195, 70.41.3.18, 150.172.238.178", Optional.of("150.172.238.178"))
);
}
}

View File

@@ -4,17 +4,16 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import java.net.InetAddress;
@@ -22,7 +21,8 @@ import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
@@ -32,8 +32,6 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@@ -199,4 +197,23 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
null)
);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ParameterizedTest
@MethodSource("argumentsForGetMostRecentProxy")
void getMostRecentProxy(final String forwardedFor, final Optional<String> expectedMostRecentProxy) {
assertEquals(expectedMostRecentProxy, WebsocketHandshakeCompleteHandler.getMostRecentProxy(forwardedFor));
}
private static Stream<Arguments> argumentsForGetMostRecentProxy() {
return Stream.of(
arguments(null, Optional.empty()),
arguments("", Optional.empty()),
arguments(" ", Optional.empty()),
arguments("203.0.113.195,", Optional.empty()),
arguments("203.0.113.195, ", Optional.empty()),
arguments("203.0.113.195", Optional.of("203.0.113.195")),
arguments("203.0.113.195, 70.41.3.18, 150.172.238.178", Optional.of("150.172.238.178"))
);
}
}

View File

@@ -246,7 +246,7 @@ class MetricsHttpChannelListenerIntegrationTest {
metricsHttpChannelListener.configure(environment);
environment.lifecycle().addEventListener(new TestListener(COUNT_DOWN_LATCH_FUTURE_REFERENCE));
environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.jersey().register(new TestResource());