Refactor remote address/X-Forwarded-For handling

This commit is contained in:
Chris Eager
2024-02-02 12:53:02 -06:00
committed by Chris Eager
parent 4475d65780
commit 2ab14ca59e
22 changed files with 599 additions and 146 deletions

View File

@@ -65,6 +65,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private final WebsocketRequestLog requestLog;
private final Duration idleTimeout;
private final String remoteAddress;
private final String remoteAddressPropertyName;
private Session session;
private RemoteEndpoint remoteEndpoint;
@@ -73,6 +74,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
public WebSocketResourceProvider(String remoteAddress,
String remoteAddressPropertyName,
ApplicationHandler jerseyHandler,
WebsocketRequestLog requestLog,
T authenticated,
@@ -80,6 +82,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
Optional<WebSocketConnectListener> connectListener,
Duration idleTimeout) {
this.remoteAddress = remoteAddress;
this.remoteAddressPropertyName = remoteAddressPropertyName;
this.jerseyHandler = jerseyHandler;
this.requestLog = requestLog;
this.authenticated = authenticated;
@@ -169,6 +172,8 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
}
containerRequest.setProperty(remoteAddressPropertyName, remoteAddress);
ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply(
containerRequest, responseBody);

View File

@@ -6,13 +6,12 @@ package org.whispersystems.websocket;
import static java.util.Optional.ofNullable;
import com.google.common.net.HttpHeaders;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.Arrays;
import java.util.Optional;
import javax.ws.rs.InternalServerErrorException;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketCreator;
@@ -38,8 +37,10 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
private final ApplicationHandler jerseyApplicationHandler;
private final WebSocketConfiguration configuration;
private final String remoteAddressPropertyName;
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass,
WebSocketConfiguration configuration) {
WebSocketConfiguration configuration, String remoteAddressPropertyName) {
this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
@@ -49,6 +50,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
this.configuration = configuration;
this.remoteAddressPropertyName = remoteAddressPropertyName;
}
@Override
@@ -69,6 +71,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
}
return new WebSocketResourceProvider<>(getRemoteAddress(request),
remoteAddressPropertyName,
this.jerseyApplicationHandler,
this.environment.getRequestLog(),
authenticated,
@@ -93,18 +96,11 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
}
private String getRemoteAddress(JettyServerUpgradeRequest request) {
String forwardedFor = request.getHeader(HttpHeaders.X_FORWARDED_FOR);
if (forwardedFor == null || forwardedFor.isBlank()) {
if (request.getRemoteSocketAddress() instanceof InetSocketAddress inetSocketAddress) {
return inetSocketAddress.getAddress().getHostAddress();
}
return null;
} else {
return Arrays.stream(forwardedFor.split(","))
.map(String::trim)
.reduce((a, b) -> b)
.orElseThrow();
final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName);
if (StringUtils.isBlank(remoteAddress)) {
logger.error("Remote address property is not present");
throw new InternalServerErrorException();
}
return remoteAddress;
}
}

View File

@@ -18,8 +18,8 @@ import java.io.IOException;
import java.security.Principal;
import java.util.Optional;
import javax.security.auth.Subject;
import javax.servlet.http.HttpServletRequest;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
@@ -33,6 +33,8 @@ import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactoryTest {
private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.websocket.test.remoteAddress";
private ResourceConfig jerseyEnvironment;
private WebSocketEnvironment<Account> environment;
private WebSocketAuthenticator<Account> authenticator;
@@ -59,7 +61,7 @@ public class WebSocketResourceProviderFactoryTest {
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class));
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
@@ -69,24 +71,25 @@ public class WebSocketResourceProviderFactoryTest {
@Test
void testValidAuthorization() throws AuthenticationException {
Session session = mock(Session.class);
Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
when(environment.jersey()).thenReturn(jerseyEnvironment);
when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class));
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class));
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNotNull(connection);
verifyNoMoreInteractions(response);
verify(authenticator).authenticate(eq(request));
((WebSocketResourceProvider<?>) connection).onWebSocketConnect(session);
((WebSocketResourceProvider<?>) connection).onWebSocketConnect(mock(Session.class));
assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated());
assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account);
@@ -100,7 +103,8 @@ public class WebSocketResourceProviderFactoryTest {
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class));
mock(WebSocketConfiguration.class),
REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
@@ -115,7 +119,8 @@ public class WebSocketResourceProviderFactoryTest {
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class));
mock(WebSocketConfiguration.class),
REMOTE_ADDRESS_PROPERTY_NAME);
factory.configure(servletFactory);
verify(servletFactory).setCreator(eq(factory));

View File

@@ -70,12 +70,15 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener;
class WebSocketResourceProviderTest {
private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.weboscket.test.remoteAddress";
@Test
void testOnConnect() {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketConnectListener connectListener = mock(WebSocketConnectListener.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME,
applicationHandler, requestLog,
new TestPrincipal("fooz"),
new ProtobufWebSocketMessageFactory(),
@@ -104,9 +107,9 @@ class WebSocketResourceProviderTest {
void testMockedRouteMessageSuccess() throws Exception {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -180,9 +183,9 @@ class WebSocketResourceProviderTest {
void testMockedRouteMessageFailure() throws Exception {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -236,9 +239,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -276,9 +279,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -316,9 +319,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("authorizedUserName"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -356,8 +359,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -394,9 +398,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("something"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -434,8 +438,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -473,9 +478,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -514,9 +519,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -556,9 +561,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -596,9 +601,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
Duration.ofMillis(30000));
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);