Copy headers from the initial websocket upgrade request into subsequent resource requests.

This commit is contained in:
Jon Chambers
2021-03-09 14:18:21 -05:00
committed by Jon Chambers
parent 933dd81d82
commit 3cdc58200a
3 changed files with 71 additions and 12 deletions

View File

@@ -37,6 +37,7 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
@@ -59,7 +60,8 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private Session session;
private RemoteEndpoint remoteEndpoint;
private WebSocketSessionContext context;
private String userAgent;
private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
public WebSocketResourceProvider(String remoteAddress,
ApplicationHandler jerseyHandler,
@@ -81,7 +83,6 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
@Override
public void onWebSocketConnect(Session session) {
this.session = session;
this.userAgent = session.getUpgradeRequest().getHeader("User-Agent");
this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext(new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
this.context.setAuthenticated(authenticated);
@@ -142,16 +143,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private void handleRequest(WebSocketRequestMessage requestMessage) {
ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()), requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)), new MapPropertiesDelegate(new HashMap<>()), null);
for (Map.Entry<String, String> entry : requestMessage.getHeaders().entrySet()) {
containerRequest.header(entry.getKey(), entry.getValue());
}
final List<String> requestUserAgentHeader = containerRequest.getRequestHeader("User-Agent");
if ((requestUserAgentHeader == null || requestUserAgentHeader.isEmpty()) && userAgent != null) {
containerRequest.header("User-Agent", userAgent);
}
containerRequest.headers(getCombinedHeaders(session.getUpgradeRequest().getHeaders(), requestMessage.getHeaders()));
if (requestMessage.getBody().isPresent()) {
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
@@ -171,6 +163,31 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
});
}
@VisibleForTesting
static Map<String, List<String>> getCombinedHeaders(final Map<String, List<String>> upgradeRequestHeaders, final Map<String, String> requestMessageHeaders) {
final Map<String, List<String>> combinedHeaders = new HashMap<>();
upgradeRequestHeaders.entrySet().stream()
.filter(entry -> shouldIncludeUpgradeRequestHeader(entry.getKey()))
.forEach(entry -> combinedHeaders.put(entry.getKey(), entry.getValue()));
requestMessageHeaders.entrySet().stream()
.filter(entry -> shouldIncludeRequestMessageHeader(entry.getKey()))
.forEach(entry -> combinedHeaders.put(entry.getKey(), List.of(entry.getValue())));
return combinedHeaders;
}
@VisibleForTesting
static boolean shouldIncludeUpgradeRequestHeader(final String header) {
return !EXCLUDED_UPGRADE_REQUEST_HEADERS.contains(header.toLowerCase()) && !header.toLowerCase().contains("websocket-");
}
@VisibleForTesting
static boolean shouldIncludeRequestMessageHeader(final String header) {
return !"X-Forwarded-For".equalsIgnoreCase(header.trim());
}
private void handleResponse(WebSocketResponseMessage responseMessage) {
CompletableFuture<WebSocketResponseMessage> future = requestMap.remove(responseMessage.getRequestId());