Update to dropwizard 2.x

This commit is contained in:
Moxie Marlinspike
2020-03-06 17:39:31 -08:00
parent 69285f28ad
commit 009f81a9a6
45 changed files with 1782 additions and 3011 deletions

View File

@@ -17,32 +17,33 @@
package org.whispersystems.websocket;
import com.google.common.annotations.VisibleForTesting;
import org.eclipse.jetty.server.RequestLog;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.glassfish.jersey.internal.MapPropertiesDelegate;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketRequestMessage;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.servlet.LoggableRequest;
import org.whispersystems.websocket.servlet.LoggableResponse;
import org.whispersystems.websocket.servlet.NullServletResponse;
import org.whispersystems.websocket.servlet.WebSocketServletRequest;
import org.whispersystems.websocket.servlet.WebSocketServletResponse;
import org.whispersystems.websocket.session.ContextPrincipal;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.Response;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -52,31 +53,34 @@ import java.util.concurrent.ConcurrentHashMap;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class WebSocketResourceProvider implements WebSocketListener {
public class WebSocketResourceProvider<T extends Principal> implements WebSocketListener {
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProvider.class);
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
private final Object authenticated;
private final T authenticated;
private final WebSocketMessageFactory messageFactory;
private final Optional<WebSocketConnectListener> connectListener;
private final HttpServlet servlet;
private final RequestLog requestLog;
private final ApplicationHandler jerseyHandler;
private final WebsocketRequestLog requestLog;
private final long idleTimeoutMillis;
private final String remoteAddress;
private Session session;
private RemoteEndpoint remoteEndpoint;
private WebSocketSessionContext context;
public WebSocketResourceProvider(HttpServlet servlet,
RequestLog requestLog,
Object authenticated,
public WebSocketResourceProvider(String remoteAddress,
ApplicationHandler jerseyHandler,
WebsocketRequestLog requestLog,
T authenticated,
WebSocketMessageFactory messageFactory,
Optional<WebSocketConnectListener> connectListener,
long idleTimeoutMillis)
{
this.servlet = servlet;
this.remoteAddress = remoteAddress;
this.jerseyHandler = jerseyHandler;
this.requestLog = requestLog;
this.authenticated = authenticated;
this.messageFactory = messageFactory;
@@ -131,7 +135,7 @@ public class WebSocketResourceProvider implements WebSocketListener {
context.notifyClosed(statusCode, reason);
for (long requestId : requestMap.keySet()) {
CompletableFuture outstandingRequest = requestMap.remove(requestId);
CompletableFuture<WebSocketResponseMessage> outstandingRequest = requestMap.remove(requestId);
if (outstandingRequest != null) {
outstandingRequest.completeExceptionally(new IOException("Connection closed!"));
@@ -146,17 +150,28 @@ public class WebSocketResourceProvider implements WebSocketListener {
}
private void handleRequest(WebSocketRequestMessage requestMessage) {
try {
HttpServletRequest servletRequest = createRequest(requestMessage, context);
HttpServletResponse servletResponse = createResponse(requestMessage);
ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()), requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)), new MapPropertiesDelegate(new HashMap<>()), null);
servlet.service(servletRequest, servletResponse);
servletResponse.flushBuffer();
requestLog.log(new LoggableRequest(servletRequest), new LoggableResponse(servletResponse));
} catch (IOException | ServletException e) {
logger.warn("Servlet Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n" + requestMessage.getBody(), e);
sendErrorResponse(requestMessage, Response.status(500).build());
for (Map.Entry<String, String> entry : requestMessage.getHeaders().entrySet()) {
containerRequest.header(entry.getKey(), entry.getValue());
}
if (requestMessage.getBody().isPresent()) {
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
}
ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply(containerRequest, responseBody);
responseFuture.thenAccept(response -> {
sendResponse(requestMessage, response, responseBody);
requestLog.log(remoteAddress, containerRequest, response);
}).exceptionally(exception -> {
logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n" + requestMessage.getBody(), exception);
sendErrorResponse(requestMessage, Response.status(500).build());
requestLog.log(remoteAddress, containerRequest, new ContainerResponse(containerRequest, Response.status(500).build()));
return null;
});
}
private void handleResponse(WebSocketResponseMessage responseMessage) {
@@ -171,17 +186,22 @@ public class WebSocketResourceProvider implements WebSocketListener {
session.close(status, message);
}
private HttpServletRequest createRequest(WebSocketRequestMessage message,
WebSocketSessionContext context)
{
return new WebSocketServletRequest(context, message, servlet.getServletContext());
}
private void sendResponse(WebSocketRequestMessage requestMessage, ContainerResponse response, ByteArrayOutputStream responseBody) {
if (requestMessage.hasRequestId()) {
byte[] body = responseBody.toByteArray();
private HttpServletResponse createResponse(WebSocketRequestMessage message) {
if (message.hasRequestId()) {
return new WebSocketServletResponse(remoteEndpoint, message.getRequestId(), messageFactory);
} else {
return new NullServletResponse();
if (body.length <= 0) {
body = null;
}
byte[] responseBytes = messageFactory.createResponse(requestMessage.getRequestId(),
response.getStatus(),
response.getStatusInfo().getReasonPhrase(),
new LinkedList<>(),
Optional.ofNullable(body))
.toByteArray();
remoteEndpoint.sendBytesByFuture(ByteBuffer.wrap(responseBytes));
}
}
@@ -203,8 +223,10 @@ public class WebSocketResourceProvider implements WebSocketListener {
}
}
@VisibleForTesting
WebSocketSessionContext getContext() {
return context;
}
}

View File

@@ -16,77 +16,56 @@
*/
package org.whispersystems.websocket;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.util.AttributesMap;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.glassfish.jersey.server.ApplicationHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.internal.WebSocketAuthValueFactoryProvider;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.RequestDispatcher;
import javax.servlet.Servlet;
import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRegistration;
import javax.servlet.SessionCookieConfig;
import javax.servlet.SessionTrackingMode;
import javax.servlet.descriptor.JspConfigDescriptor;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.AccessController;
import java.util.Collections;
import java.util.Enumeration;
import java.util.EventListener;
import java.util.Map;
import java.security.Principal;
import java.util.Arrays;
import java.util.Optional;
import java.util.Set;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import static java.util.Optional.ofNullable;
public class WebSocketResourceProviderFactory extends WebSocketServlet implements WebSocketCreator {
public class WebSocketResourceProviderFactory<T extends Principal> extends WebSocketServlet implements WebSocketCreator {
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class);
private final WebSocketEnvironment environment;
private final WebSocketEnvironment<T> environment;
private final ApplicationHandler jerseyApplicationHandler;
public WebSocketResourceProviderFactory(WebSocketEnvironment environment)
throws ServletException
{
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass) {
this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
environment.jersey().register(new WebSocketAuthValueFactoryProvider.Binder());
environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder<T>(principalClass));
environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper()));
}
public void start() throws ServletException {
this.environment.getJerseyServletContainer().init(new WServletConfig());
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
}
@Override
public Object createWebSocket(ServletUpgradeRequest request, ServletUpgradeResponse response) {
try {
Optional<WebSocketAuthenticator> authenticator = Optional.ofNullable(environment.getAuthenticator());
Object authenticated = null;
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
T authenticated = null;
if (authenticator.isPresent()) {
AuthenticationResult authenticationResult = authenticator.get().authenticate(request);
AuthenticationResult<T> authenticationResult = authenticator.get().authenticate(request);
if (!authenticationResult.getUser().isPresent() && authenticationResult.isRequired()) {
if (authenticationResult.getUser().isEmpty() && authenticationResult.isRequired()) {
response.sendForbidden("Unauthorized");
return null;
} else {
@@ -94,14 +73,18 @@ public class WebSocketResourceProviderFactory extends WebSocketServlet implement
}
}
return new WebSocketResourceProvider(this.environment.getJerseyServletContainer(),
this.environment.getRequestLog(),
authenticated,
this.environment.getMessageFactory(),
Optional.ofNullable(this.environment.getConnectListener()),
this.environment.getIdleTimeoutMillis());
return new WebSocketResourceProvider<T>(getRemoteAddress(request),
this.jerseyApplicationHandler,
this.environment.getRequestLog(),
authenticated,
this.environment.getMessageFactory(),
ofNullable(this.environment.getConnectListener()),
this.environment.getIdleTimeoutMillis());
} catch (AuthenticationException | IOException e) {
logger.warn("Authentication failure", e);
try {
response.sendError(500, "Failure");
} catch (IOException ex) {}
return null;
}
}
@@ -111,358 +94,16 @@ public class WebSocketResourceProviderFactory extends WebSocketServlet implement
factory.setCreator(this);
}
private static class WServletConfig implements ServletConfig {
private String getRemoteAddress(ServletUpgradeRequest request) {
String forwardedFor = request.getHeader("X-Forwarded-For");
private final ServletContext context = new NoContext();
@Override
public String getServletName() {
return "WebSocketResourceServlet";
}
@Override
public ServletContext getServletContext() {
return context;
}
@Override
public String getInitParameter(String name) {
return null;
}
@Override
public Enumeration<String> getInitParameterNames() {
return new Enumeration<String>() {
@Override
public boolean hasMoreElements() {
return false;
}
@Override
public String nextElement() {
return null;
}
};
if (forwardedFor == null || forwardedFor.isBlank()) {
return request.getRemoteAddress();
} else {
return Arrays.stream(forwardedFor.split(","))
.map(String::trim)
.reduce((a, b) -> b)
.orElseThrow();
}
}
public static class NoContext extends AttributesMap implements ServletContext
{
private int effectiveMajorVersion = 3;
private int effectiveMinorVersion = 0;
@Override
public ServletContext getContext(String uripath)
{
return null;
}
@Override
public int getMajorVersion()
{
return 3;
}
@Override
public String getMimeType(String file)
{
return null;
}
@Override
public int getMinorVersion()
{
return 0;
}
@Override
public RequestDispatcher getNamedDispatcher(String name)
{
return null;
}
@Override
public RequestDispatcher getRequestDispatcher(String uriInContext)
{
return null;
}
@Override
public String getRealPath(String path)
{
return null;
}
@Override
public URL getResource(String path) throws MalformedURLException
{
return null;
}
@Override
public InputStream getResourceAsStream(String path)
{
return null;
}
@Override
public Set<String> getResourcePaths(String path)
{
return null;
}
@Override
public String getServerInfo()
{
return "websocketresources/" + Server.getVersion();
}
@Override
@Deprecated
public Servlet getServlet(String name) throws ServletException
{
return null;
}
@SuppressWarnings("unchecked")
@Override
@Deprecated
public Enumeration<String> getServletNames()
{
return Collections.enumeration(Collections.EMPTY_LIST);
}
@SuppressWarnings("unchecked")
@Override
@Deprecated
public Enumeration<Servlet> getServlets()
{
return Collections.enumeration(Collections.EMPTY_LIST);
}
@Override
public void log(Exception exception, String msg)
{
logger.warn(msg,exception);
}
@Override
public void log(String msg)
{
logger.info(msg);
}
@Override
public void log(String message, Throwable throwable)
{
logger.warn(message,throwable);
}
@Override
public String getInitParameter(String name)
{
return null;
}
@SuppressWarnings("unchecked")
@Override
public Enumeration<String> getInitParameterNames()
{
return Collections.enumeration(Collections.EMPTY_LIST);
}
@Override
public String getServletContextName()
{
return "No Context";
}
@Override
public String getContextPath()
{
return null;
}
@Override
public boolean setInitParameter(String name, String value)
{
return false;
}
@Override
public FilterRegistration.Dynamic addFilter(String filterName, Class<? extends Filter> filterClass)
{
return null;
}
@Override
public FilterRegistration.Dynamic addFilter(String filterName, Filter filter)
{
return null;
}
@Override
public FilterRegistration.Dynamic addFilter(String filterName, String className)
{
return null;
}
@Override
public javax.servlet.ServletRegistration.Dynamic addServlet(String servletName, Class<? extends Servlet> servletClass)
{
return null;
}
@Override
public javax.servlet.ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet)
{
return null;
}
@Override
public javax.servlet.ServletRegistration.Dynamic addServlet(String servletName, String className)
{
return null;
}
@Override
public <T extends Filter> T createFilter(Class<T> c) throws ServletException
{
return null;
}
@Override
public <T extends Servlet> T createServlet(Class<T> c) throws ServletException
{
return null;
}
@Override
public Set<SessionTrackingMode> getDefaultSessionTrackingModes()
{
return null;
}
@Override
public Set<SessionTrackingMode> getEffectiveSessionTrackingModes()
{
return null;
}
@Override
public FilterRegistration getFilterRegistration(String filterName)
{
return null;
}
@Override
public Map<String, ? extends FilterRegistration> getFilterRegistrations()
{
return null;
}
@Override
public ServletRegistration getServletRegistration(String servletName)
{
return null;
}
@Override
public Map<String, ? extends ServletRegistration> getServletRegistrations()
{
return null;
}
@Override
public SessionCookieConfig getSessionCookieConfig()
{
return null;
}
@Override
public void setSessionTrackingModes(Set<SessionTrackingMode> sessionTrackingModes)
{
}
@Override
public void addListener(String className)
{
}
@Override
public <T extends EventListener> void addListener(T t)
{
}
@Override
public void addListener(Class<? extends EventListener> listenerClass)
{
}
@Override
public <T extends EventListener> T createListener(Class<T> clazz) throws ServletException
{
try
{
return clazz.newInstance();
}
catch (InstantiationException e)
{
throw new ServletException(e);
}
catch (IllegalAccessException e)
{
throw new ServletException(e);
}
}
@Override
public ClassLoader getClassLoader()
{
AccessController.checkPermission(new RuntimePermission("getClassLoader"));
return WebSocketResourceProviderFactory.class.getClassLoader();
}
@Override
public int getEffectiveMajorVersion()
{
return effectiveMajorVersion;
}
@Override
public int getEffectiveMinorVersion()
{
return effectiveMinorVersion;
}
public void setEffectiveMajorVersion (int v)
{
this.effectiveMajorVersion = v;
}
public void setEffectiveMinorVersion (int v)
{
this.effectiveMinorVersion = v;
}
@Override
public JspConfigDescriptor getJspConfigDescriptor()
{
return null;
}
@Override
public void declareRoles(String... roleNames)
{
}
@Override
public String getVirtualServerName() {
return null;
}
}
}

View File

@@ -0,0 +1,40 @@
package org.whispersystems.websocket;
import org.whispersystems.websocket.session.ContextPrincipal;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.ws.rs.core.SecurityContext;
import java.security.Principal;
public class WebSocketSecurityContext implements SecurityContext {
private final ContextPrincipal principal;
public WebSocketSecurityContext(ContextPrincipal principal) {
this.principal = principal;
}
@Override
public Principal getUserPrincipal() {
return (Principal)principal.getContext().getAuthenticated();
}
@Override
public boolean isUserInRole(String role) {
return false;
}
@Override
public boolean isSecure() {
return principal != null;
}
@Override
public String getAuthenticationScheme() {
return null;
}
public WebSocketSessionContext getSessionContext() {
return principal.getContext();
}
}

View File

@@ -19,9 +19,10 @@ package org.whispersystems.websocket.auth;
import org.eclipse.jetty.server.Authentication;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import java.security.Principal;
import java.util.Optional;
public interface WebSocketAuthenticator<T> {
public interface WebSocketAuthenticator<T extends Principal> {
AuthenticationResult<T> authenticate(UpgradeRequest request) throws AuthenticationException;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")

View File

@@ -0,0 +1,114 @@
package org.whispersystems.websocket.auth;
import org.glassfish.jersey.internal.inject.AbstractBinder;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider;
import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider;
import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Singleton;
import javax.ws.rs.WebApplicationException;
import java.lang.reflect.ParameterizedType;
import java.security.Principal;
import java.util.Optional;
import java.util.function.Function;
import io.dropwizard.auth.Auth;
@Singleton
public class WebsocketAuthValueFactoryProvider<T extends Principal> extends AbstractValueParamProvider {
private final Class<T> principalClass;
@Inject
public WebsocketAuthValueFactoryProvider(MultivaluedParameterExtractorProvider mpep, WebsocketPrincipalClassProvider<T> principalClassProvider) {
super(() -> mpep, Parameter.Source.UNKNOWN);
this.principalClass = principalClassProvider.clazz;
}
@Nullable
@Override
protected Function<ContainerRequest, ?> createValueProvider(Parameter parameter) {
if (!parameter.isAnnotationPresent(Auth.class)) {
return null;
}
if (parameter.getRawType() == Optional.class &&
ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) &&
principalClass == ((ParameterizedType)parameter.getType()).getActualTypeArguments()[0])
{
return request -> new OptionalContainerRequestValueFactory(request).provide();
} else if (principalClass.equals(parameter.getRawType())) {
return request -> new StandardContainerRequestValueFactory(request).provide();
} else {
throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter);
}
}
@Singleton
static class WebsocketPrincipalClassProvider<T extends Principal> {
private final Class<T> clazz;
WebsocketPrincipalClassProvider(Class<T> clazz) {
this.clazz = clazz;
}
}
/**
* Injection binder for {@link io.dropwizard.auth.AuthValueFactoryProvider}.
*
* @param <T> the type of the principal
*/
public static class Binder<T extends Principal> extends AbstractBinder {
private final Class<T> principalClass;
public Binder(Class<T> principalClass) {
this.principalClass = principalClass;
}
@Override
protected void configure() {
bind(new WebsocketPrincipalClassProvider<>(principalClass)).to(WebsocketPrincipalClassProvider.class);
bind(WebsocketAuthValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class);
}
}
private static class StandardContainerRequestValueFactory {
private final ContainerRequest request;
public StandardContainerRequestValueFactory(ContainerRequest request) {
this.request = request;
}
public Principal provide() {
final Principal principal = request.getSecurityContext().getUserPrincipal();
if (principal == null) {
throw new WebApplicationException("Authenticated resource", 401);
}
return principal;
}
}
private static class OptionalContainerRequestValueFactory {
private final ContainerRequest request;
public OptionalContainerRequestValueFactory(ContainerRequest request) {
this.request = request;
}
public Optional<Principal> provide() {
return Optional.ofNullable(request.getSecurityContext().getUserPrincipal());
}
}
}

View File

@@ -1,120 +0,0 @@
package org.whispersystems.websocket.auth.internal;
import org.glassfish.hk2.api.InjectionResolver;
import org.glassfish.hk2.api.ServiceLocator;
import org.glassfish.hk2.api.TypeLiteral;
import org.glassfish.hk2.utilities.binding.AbstractBinder;
import org.glassfish.jersey.server.internal.inject.AbstractContainerRequestValueFactory;
import org.glassfish.jersey.server.internal.inject.AbstractValueFactoryProvider;
import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider;
import org.glassfish.jersey.server.internal.inject.ParamInjectionResolver;
import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueFactoryProvider;
import org.whispersystems.websocket.servlet.WebSocketServletRequest;
import javax.inject.Inject;
import javax.inject.Singleton;
import javax.ws.rs.WebApplicationException;
import java.security.Principal;
import java.util.Optional;
import io.dropwizard.auth.Auth;
@Singleton
public class WebSocketAuthValueFactoryProvider extends AbstractValueFactoryProvider {
@Inject
public WebSocketAuthValueFactoryProvider(MultivaluedParameterExtractorProvider mpep,
ServiceLocator injector)
{
super(mpep, injector, Parameter.Source.UNKNOWN);
}
@Override
public AbstractContainerRequestValueFactory<?> createValueFactory(final Parameter parameter) {
if (parameter.getAnnotation(Auth.class) == null) {
return null;
}
if (parameter.getRawType() == Optional.class) {
return new OptionalContainerRequestValueFactory(parameter);
} else {
return new StandardContainerRequestValueFactory(parameter);
}
}
private static class OptionalContainerRequestValueFactory extends AbstractContainerRequestValueFactory {
private final Parameter parameter;
private OptionalContainerRequestValueFactory(Parameter parameter) {
this.parameter = parameter;
}
@Override
public Object provide() {
Principal principal = getContainerRequest().getSecurityContext().getUserPrincipal();
if (principal != null && !(principal instanceof WebSocketServletRequest.ContextPrincipal)) {
throw new IllegalArgumentException("Can't inject non-ContextPrincipal into request");
}
if (principal == null) return Optional.empty();
else return Optional.ofNullable(((WebSocketServletRequest.ContextPrincipal)principal).getContext().getAuthenticated());
}
}
private static class StandardContainerRequestValueFactory extends AbstractContainerRequestValueFactory {
private final Parameter parameter;
private StandardContainerRequestValueFactory(Parameter parameter) {
this.parameter = parameter;
}
@Override
public Object provide() {
Principal principal = getContainerRequest().getSecurityContext().getUserPrincipal();
if (principal == null) {
throw new IllegalStateException("Cannot inject a custom principal into unauthenticated request");
}
if (!(principal instanceof WebSocketServletRequest.ContextPrincipal)) {
throw new IllegalArgumentException("Cannot inject a non-WebSocket AuthPrincipal into request");
}
Object authenticated = ((WebSocketServletRequest.ContextPrincipal)principal).getContext().getAuthenticated();
if (authenticated == null) {
throw new WebApplicationException("Authenticated resource", 401);
}
if (!parameter.getRawType().isAssignableFrom(authenticated.getClass())) {
throw new IllegalArgumentException("Authenticated principal is of the wrong type: " + authenticated.getClass() + " looking for: " + parameter.getRawType());
}
return parameter.getRawType().cast(authenticated);
}
}
@Singleton
private static class AuthInjectionResolver extends ParamInjectionResolver<Auth> {
public AuthInjectionResolver() {
super(WebSocketAuthValueFactoryProvider.class);
}
}
public static class Binder extends AbstractBinder {
public Binder() {
}
@Override
protected void configure() {
bind(WebSocketAuthValueFactoryProvider.class).to(ValueFactoryProvider.class).in(Singleton.class);
bind(AuthInjectionResolver.class).to(new TypeLiteral<InjectionResolver<Auth>>() {
}).in(Singleton.class);
}
}
}

View File

@@ -2,6 +2,8 @@ package org.whispersystems.websocket.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
@@ -13,9 +15,9 @@ public class WebSocketConfiguration {
@Valid
@NotNull
@JsonProperty
private RequestLogFactory requestLog = new LogbackAccessRequestLogFactory();
private WebsocketRequestLoggerFactory requestLog = new WebsocketRequestLoggerFactory();
public RequestLogFactory getRequestLog() {
public WebsocketRequestLoggerFactory getRequestLog() {
return requestLog;
}
}

View File

@@ -0,0 +1,16 @@
package org.whispersystems.websocket.logging;
import ch.qos.logback.core.AsyncAppenderBase;
import io.dropwizard.logging.async.AsyncAppenderFactory;
public class AsyncWebsocketEventAppenderFactory implements AsyncAppenderFactory<WebsocketEvent> {
@Override
public AsyncAppenderBase<WebsocketEvent> build() {
return new AsyncAppenderBase<WebsocketEvent>() {
@Override
protected void preprocess(WebsocketEvent event) {
event.prepareForDeferredProcessing();
}
};
}
}

View File

@@ -0,0 +1,73 @@
package org.whispersystems.websocket.logging;
import com.google.common.annotations.VisibleForTesting;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import javax.ws.rs.core.MultivaluedMap;
import java.util.List;
import ch.qos.logback.core.spi.DeferredProcessingAware;
public class WebsocketEvent implements DeferredProcessingAware {
public static final int SENTINEL = -1;
public static final String NA = "-";
private final String remoteAddress;
private final ContainerRequest request;
private final ContainerResponse response;
private final long timestamp;
public WebsocketEvent(String remoteAddress, ContainerRequest jerseyRequest, ContainerResponse jettyResponse) {
this.timestamp = System.currentTimeMillis();
this.remoteAddress = remoteAddress;
this.request = jerseyRequest;
this.response = jettyResponse;
}
public String getRemoteHost() {
return remoteAddress;
}
public long getTimestamp() {
return timestamp;
}
@Override
public void prepareForDeferredProcessing() {
}
public String getMethod() {
return request.getMethod();
}
public String getPath() {
return request.getBaseUri().getPath() + request.getPath(false);
}
public String getProtocol() {
return "WS";
}
public int getStatusCode() {
return response.getStatus();
}
public long getContentLength() {
return response.getLength();
}
public String getRequestHeader(String key) {
List<String> values = request.getRequestHeader(key);
if (values == null) return NA;
else return values.stream().findFirst().orElse(NA);
}
public MultivaluedMap<String, String> getRequestHeaderMap() {
return request.getRequestHeaders();
}
}

View File

@@ -0,0 +1,43 @@
package org.whispersystems.websocket.logging;
import com.google.common.annotations.VisibleForTesting;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import ch.qos.logback.core.Appender;
import ch.qos.logback.core.filter.Filter;
import ch.qos.logback.core.spi.AppenderAttachableImpl;
import ch.qos.logback.core.spi.FilterAttachableImpl;
import ch.qos.logback.core.spi.FilterReply;
public class WebsocketRequestLog {
private AppenderAttachableImpl<WebsocketEvent> aai = new AppenderAttachableImpl<>();
private FilterAttachableImpl<WebsocketEvent> fai = new FilterAttachableImpl<>();
public WebsocketRequestLog() {
}
public void log(String remoteAddress, ContainerRequest jerseyRequest, ContainerResponse jettyResponse) {
WebsocketEvent event = new WebsocketEvent(remoteAddress, jerseyRequest, jettyResponse);
if (getFilterChainDecision(event) == FilterReply.DENY) {
return;
}
aai.appendLoopOnAppenders(event);
}
public void addAppender(Appender<WebsocketEvent> newAppender) {
aai.addAppender(newAppender);
}
public void addFilter(Filter<WebsocketEvent> newFilter) {
fai.addFilter(newFilter);
}
public FilterReply getFilterChainDecision(WebsocketEvent event) {
return fai.getFilterChainDecision(event);
}
}

View File

@@ -0,0 +1,45 @@
package org.whispersystems.websocket.logging;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.logging.layout.WebsocketEventLayoutFactory;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import java.util.Collections;
import java.util.List;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.LoggerContext;
import io.dropwizard.logging.AppenderFactory;
import io.dropwizard.logging.ConsoleAppenderFactory;
import io.dropwizard.logging.async.AsyncAppenderFactory;
import io.dropwizard.logging.filter.LevelFilterFactory;
import io.dropwizard.logging.filter.NullLevelFilterFactory;
import io.dropwizard.logging.layout.LayoutFactory;
public class WebsocketRequestLoggerFactory {
@VisibleForTesting
@Valid
@NotNull
public List<AppenderFactory<WebsocketEvent>> appenders = Collections.singletonList(new ConsoleAppenderFactory<>());
public WebsocketRequestLog build(String name) {
final Logger logger = (Logger) LoggerFactory.getLogger("websocket.request");
logger.setAdditive(false);
final LoggerContext context = logger.getLoggerContext();
final WebsocketRequestLog requestLog = new WebsocketRequestLog();
final LevelFilterFactory<WebsocketEvent> levelFilterFactory = new NullLevelFilterFactory<>();
final AsyncAppenderFactory<WebsocketEvent> asyncAppenderFactory = new AsyncWebsocketEventAppenderFactory();
final LayoutFactory<WebsocketEvent> layoutFactory = new WebsocketEventLayoutFactory();
for (AppenderFactory<WebsocketEvent> output : appenders) {
requestLog.addAppender(output.build(context, name, layoutFactory, levelFilterFactory, asyncAppenderFactory));
}
return requestLog;
}
}

View File

@@ -0,0 +1,77 @@
package org.whispersystems.websocket.logging.layout;
import org.whispersystems.websocket.logging.WebsocketEvent;
import org.whispersystems.websocket.logging.layout.converters.ContentLengthConverter;
import org.whispersystems.websocket.logging.layout.converters.DateConverter;
import org.whispersystems.websocket.logging.layout.converters.EnsureLineSeparation;
import org.whispersystems.websocket.logging.layout.converters.NAConverter;
import org.whispersystems.websocket.logging.layout.converters.RemoteHostConverter;
import org.whispersystems.websocket.logging.layout.converters.RequestHeaderConverter;
import org.whispersystems.websocket.logging.layout.converters.RequestUrlConverter;
import org.whispersystems.websocket.logging.layout.converters.StatusCodeConverter;
import java.util.HashMap;
import java.util.Map;
import ch.qos.logback.core.Context;
import ch.qos.logback.core.pattern.PatternLayoutBase;
public class WebsocketEventLayout extends PatternLayoutBase<WebsocketEvent> {
private static final Map<String, String> DEFAULT_CONVERTERS = new HashMap<>() {{
put("h", RemoteHostConverter.class.getName());
put("l", NAConverter.class.getName());
put("u", NAConverter.class.getName());
put("t", DateConverter.class.getName());
put("r", RequestUrlConverter.class.getName());
put("s", StatusCodeConverter.class.getName());
put("b", ContentLengthConverter.class.getName());
put("i", RequestHeaderConverter.class.getName());
}};
public static final String CLF_PATTERN = "%h %l %u [%t] \"%r\" %s %b";
public static final String CLF_PATTERN_NAME = "common";
public static final String CLF_PATTERN_NAME_2 = "clf";
public static final String COMBINED_PATTERN = "%h %l %u [%t] \"%r\" %s %b \"%i{Referer}\" \"%i{User-Agent}\"";
public static final String COMBINED_PATTERN_NAME = "combined";
public static final String HEADER_PREFIX = "#logback.access pattern: ";
public WebsocketEventLayout(Context context) {
setOutputPatternAsHeader(false);
setPattern(COMBINED_PATTERN);
setContext(context);
this.postCompileProcessor = new EnsureLineSeparation();
}
@Override
public Map<String, String> getDefaultConverterMap() {
return DEFAULT_CONVERTERS;
}
@Override
public String doLayout(WebsocketEvent event) {
if (!isStarted()) {
return null;
}
return writeLoopOnConverters(event);
}
@Override
public void start() {
if (getPattern().equalsIgnoreCase(CLF_PATTERN_NAME) || getPattern().equalsIgnoreCase(CLF_PATTERN_NAME_2)) {
setPattern(CLF_PATTERN);
} else if (getPattern().equalsIgnoreCase(COMBINED_PATTERN_NAME)) {
setPattern(COMBINED_PATTERN);
}
super.start();
}
@Override
protected String getPresentationHeaderPrefix() {
return HEADER_PREFIX;
}
}

View File

@@ -0,0 +1,16 @@
package org.whispersystems.websocket.logging.layout;
import org.whispersystems.websocket.logging.WebsocketEvent;
import java.util.TimeZone;
import ch.qos.logback.classic.LoggerContext;
import ch.qos.logback.core.pattern.PatternLayoutBase;
import io.dropwizard.logging.layout.LayoutFactory;
public class WebsocketEventLayoutFactory implements LayoutFactory<WebsocketEvent> {
@Override
public PatternLayoutBase<WebsocketEvent> build(LoggerContext context, TimeZone timeZone) {
return new WebsocketEventLayout(context);
}
}

View File

@@ -0,0 +1,14 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
public class ContentLengthConverter extends WebSocketEventConverter {
@Override
public String convert(WebsocketEvent event) {
if (event.getContentLength() == WebsocketEvent.SENTINEL) {
return WebsocketEvent.NA;
} else {
return Long.toString(event.getContentLength());
}
}
}

View File

@@ -0,0 +1,51 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
import java.util.List;
import java.util.TimeZone;
import ch.qos.logback.core.CoreConstants;
import ch.qos.logback.core.util.CachingDateFormatter;
public class DateConverter extends WebSocketEventConverter {
private CachingDateFormatter cachingDateFormatter = null;
@Override
public void start() {
String datePattern = getFirstOption();
if (datePattern == null) {
datePattern = CoreConstants.CLF_DATE_PATTERN;
}
if (datePattern.equals(CoreConstants.ISO8601_STR)) {
datePattern = CoreConstants.ISO8601_PATTERN;
}
try {
cachingDateFormatter = new CachingDateFormatter(datePattern);
// maximumCacheValidity = CachedDateFormat.getMaximumCacheValidity(pattern);
} catch (IllegalArgumentException e) {
addWarn("Could not instantiate SimpleDateFormat with pattern " + datePattern, e);
addWarn("Defaulting to " + CoreConstants.CLF_DATE_PATTERN);
cachingDateFormatter = new CachingDateFormatter(CoreConstants.CLF_DATE_PATTERN);
}
List optionList = getOptionList();
// if the option list contains a TZ option, then set it.
if (optionList != null && optionList.size() > 1) {
TimeZone tz = TimeZone.getTimeZone((String) optionList.get(1));
cachingDateFormatter.setTimeZone(tz);
}
}
@Override
public String convert(WebsocketEvent websocketEvent) {
long timestamp = websocketEvent.getTimestamp();
return cachingDateFormatter.format(timestamp);
}
}

View File

@@ -0,0 +1,29 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
import ch.qos.logback.core.Context;
import ch.qos.logback.core.pattern.Converter;
import ch.qos.logback.core.pattern.ConverterUtil;
import ch.qos.logback.core.pattern.PostCompileProcessor;
public class EnsureLineSeparation implements PostCompileProcessor<WebsocketEvent> {
/**
* Add a line separator converter so that access event appears on a separate
* line.
*/
@Override
public void process(Context context, Converter<WebsocketEvent> head) {
if (head == null)
throw new IllegalArgumentException("Empty converter chain");
// if head != null, then tail != null as well
Converter<WebsocketEvent> tail = ConverterUtil.findTail(head);
Converter<WebsocketEvent> newLineConverter = new LineSeparatorConverter();
if (!(tail instanceof LineSeparatorConverter)) {
tail.setNext(newLineConverter);
}
}
}

View File

@@ -0,0 +1,14 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
import ch.qos.logback.core.CoreConstants;
public class LineSeparatorConverter extends WebSocketEventConverter {
public LineSeparatorConverter() {
}
public String convert(WebsocketEvent event) {
return CoreConstants.LINE_SEPARATOR;
}
}

View File

@@ -0,0 +1,10 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
public class NAConverter extends WebSocketEventConverter {
@Override
public String convert(WebsocketEvent event) {
return WebsocketEvent.NA;
}
}

View File

@@ -0,0 +1,10 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
public class RemoteHostConverter extends WebSocketEventConverter {
@Override
public String convert(WebsocketEvent event) {
return event.getRemoteHost();
}
}

View File

@@ -0,0 +1,33 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
import ch.qos.logback.core.util.OptionHelper;
public class RequestHeaderConverter extends WebSocketEventConverter {
private String key;
@Override
public void start() {
key = getFirstOption();
if (OptionHelper.isEmpty(key)) {
addWarn("Missing key for the requested header. Defaulting to all keys.");
key = null;
}
super.start();
}
@Override
public String convert(WebsocketEvent websocketEvent) {
if (!isStarted()) {
return "INACTIVE_HEADER_CONV";
}
if (key != null) {
return websocketEvent.getRequestHeader(key);
} else {
return websocketEvent.getRequestHeaderMap().toString();
}
}
}

View File

@@ -0,0 +1,15 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
public class RequestUrlConverter extends WebSocketEventConverter {
@Override
public String convert(WebsocketEvent event) {
return
event.getMethod() +
WebSocketEventConverter.SPACE_CHAR +
event.getPath() +
WebSocketEventConverter.SPACE_CHAR +
event.getProtocol();
}
}

View File

@@ -0,0 +1,14 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
public class StatusCodeConverter extends WebSocketEventConverter {
@Override
public String convert(WebsocketEvent event) {
if (event.getStatusCode() == WebsocketEvent.SENTINEL) {
return WebsocketEvent.NA;
} else {
return Integer.toString(event.getStatusCode());
}
}
}

View File

@@ -0,0 +1,63 @@
package org.whispersystems.websocket.logging.layout.converters;
import org.whispersystems.websocket.logging.WebsocketEvent;
import ch.qos.logback.core.Context;
import ch.qos.logback.core.pattern.DynamicConverter;
import ch.qos.logback.core.spi.ContextAware;
import ch.qos.logback.core.spi.ContextAwareBase;
import ch.qos.logback.core.status.Status;
public abstract class WebSocketEventConverter extends DynamicConverter<WebsocketEvent> implements ContextAware {
public final static char SPACE_CHAR = ' ';
public final static char QUESTION_CHAR = '?';
ContextAwareBase cab = new ContextAwareBase();
@Override
public void setContext(Context context) {
cab.setContext(context);
}
@Override
public Context getContext() {
return cab.getContext();
}
@Override
public void addStatus(Status status) {
cab.addStatus(status);
}
@Override
public void addInfo(String msg) {
cab.addInfo(msg);
}
@Override
public void addInfo(String msg, Throwable ex) {
cab.addInfo(msg, ex);
}
@Override
public void addWarn(String msg) {
cab.addWarn(msg);
}
@Override
public void addWarn(String msg, Throwable ex) {
cab.addWarn(msg, ex);
}
@Override
public void addError(String msg) {
cab.addError(msg);
}
@Override
public void addError(String msg, Throwable ex) {
cab.addError(msg, ex);
}
}

View File

@@ -1,66 +0,0 @@
/**
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.websocket.servlet;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
public class BufferingServletInputStream extends ServletInputStream {
private final ByteArrayInputStream buffer;
public BufferingServletInputStream(byte[] body) {
this.buffer = new ByteArrayInputStream(body);
}
@Override
public int read(byte[] buf, int offset, int length) {
return buffer.read(buf, offset, length);
}
@Override
public int read(byte[] buf) {
return read(buf, 0, buf.length);
}
@Override
public int read() throws IOException {
return buffer.read();
}
@Override
public int available() {
return buffer.available();
}
@Override
public boolean isFinished() {
return available() > 0;
}
@Override
public boolean isReady() {
return true;
}
@Override
public void setReadListener(ReadListener readListener) {
}
}

View File

@@ -1,66 +0,0 @@
/**
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.websocket.servlet;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
public class BufferingServletOutputStream extends ServletOutputStream {
private final ByteArrayOutputStream buffer;
public BufferingServletOutputStream(ByteArrayOutputStream buffer) {
this.buffer = buffer;
}
@Override
public void write(byte[] buf, int offset, int length) {
buffer.write(buf, offset, length);
}
@Override
public void write(byte[] buf) {
write(buf, 0, buf.length);
}
@Override
public void write(int b) throws IOException {
buffer.write(b);
}
@Override
public void flush() {
}
@Override
public void close() {
}
@Override
public boolean isReady() {
return true;
}
@Override
public void setWriteListener(WriteListener writeListener) {
}
}

View File

@@ -1,629 +0,0 @@
package org.whispersystems.websocket.servlet;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.server.Authentication;
import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.server.HttpChannelState;
import org.eclipse.jetty.server.HttpInput;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.UserIdentity;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.Attributes;
import javax.servlet.AsyncContext;
import javax.servlet.DispatcherType;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.Part;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.util.Collection;
import java.util.Enumeration;
import java.util.EventListener;
import java.util.Locale;
import java.util.Map;
public class LoggableRequest extends Request {
private final HttpServletRequest request;
public LoggableRequest(HttpServletRequest request) {
super(null, null);
this.request = request;
}
@Override
public HttpFields getHttpFields() {
throw new AssertionError();
}
@Override
public HttpInput getHttpInput() {
throw new AssertionError();
}
@Override
public void addEventListener(EventListener listener) {
throw new AssertionError();
}
@Override
public AsyncContext getAsyncContext() {
throw new AssertionError();
}
@Override
public HttpChannelState getHttpChannelState() {
throw new AssertionError();
}
@Override
public Object getAttribute(String name) {
return request.getAttribute(name);
}
@Override
public Enumeration<String> getAttributeNames() {
return request.getAttributeNames();
}
@Override
public Attributes getAttributes() {
throw new AssertionError();
}
@Override
public Authentication getAuthentication() {
return null;
}
@Override
public String getAuthType() {
return request.getAuthType();
}
@Override
public String getCharacterEncoding() {
return request.getCharacterEncoding();
}
@Override
public HttpChannel getHttpChannel() {
throw new AssertionError();
}
@Override
public int getContentLength() {
return request.getContentLength();
}
@Override
public String getContentType() {
return request.getContentType();
}
@Override
public ContextHandler.Context getContext() {
throw new AssertionError();
}
@Override
public String getContextPath() {
return request.getContextPath();
}
@Override
public Cookie[] getCookies() {
return request.getCookies();
}
@Override
public long getDateHeader(String name) {
return request.getDateHeader(name);
}
@Override
public DispatcherType getDispatcherType() {
return request.getDispatcherType();
}
@Override
public String getHeader(String name) {
return request.getHeader(name);
}
@Override
public Enumeration<String> getHeaderNames() {
return request.getHeaderNames();
}
@Override
public Enumeration<String> getHeaders(String name) {
return request.getHeaders(name);
}
@Override
public int getInputState() {
throw new AssertionError();
}
@Override
public ServletInputStream getInputStream() throws IOException {
return request.getInputStream();
}
@Override
public int getIntHeader(String name) {
return request.getIntHeader(name);
}
@Override
public Locale getLocale() {
return request.getLocale();
}
@Override
public Enumeration<Locale> getLocales() {
return request.getLocales();
}
@Override
public String getLocalAddr() {
return request.getLocalAddr();
}
@Override
public String getLocalName() {
return request.getLocalName();
}
@Override
public int getLocalPort() {
return request.getLocalPort();
}
@Override
public String getMethod() {
return request.getMethod();
}
@Override
public String getParameter(String name) {
return request.getParameter(name);
}
@Override
public Map<String, String[]> getParameterMap() {
return request.getParameterMap();
}
@Override
public Enumeration<String> getParameterNames() {
return request.getParameterNames();
}
@Override
public String[] getParameterValues(String name) {
return request.getParameterValues(name);
}
@Override
public String getPathInfo() {
return request.getPathInfo();
}
@Override
public String getPathTranslated() {
return request.getPathTranslated();
}
@Override
public String getProtocol() {
return request.getProtocol();
}
@Override
public HttpVersion getHttpVersion() {
throw new AssertionError();
}
@Override
public String getQueryEncoding() {
throw new AssertionError();
}
@Override
public String getQueryString() {
return request.getQueryString();
}
@Override
public BufferedReader getReader() throws IOException {
throw new AssertionError();
}
@Override
public String getRealPath(String path) {
return request.getRealPath(path);
}
@Override
public String getRemoteAddr() {
return request.getRemoteAddr();
}
@Override
public String getRemoteHost() {
return request.getRemoteHost();
}
@Override
public int getRemotePort() {
return request.getRemotePort();
}
@Override
public String getRemoteUser() {
return request.getRemoteUser();
}
@Override
public RequestDispatcher getRequestDispatcher(String path) {
return request.getRequestDispatcher(path);
}
@Override
public String getRequestedSessionId() {
return request.getRequestedSessionId();
}
@Override
public String getRequestURI() {
return request.getRequestURI();
}
@Override
public StringBuffer getRequestURL() {
return request.getRequestURL();
}
@Override
public Response getResponse() {
throw new AssertionError();
}
@Override
public StringBuilder getRootURL() {
throw new AssertionError();
}
@Override
public String getScheme() {
return request.getScheme();
}
@Override
public String getServerName() {
return request.getServerName();
}
@Override
public int getServerPort() {
return request.getServerPort();
}
@Override
public ServletContext getServletContext() {
return request.getServletContext();
}
@Override
public String getServletName() {
throw new AssertionError();
}
@Override
public String getServletPath() {
return request.getServletPath();
}
@Override
public ServletResponse getServletResponse() {
throw new AssertionError();
}
@Override
public String changeSessionId() {
throw new AssertionError();
}
@Override
public HttpSession getSession() {
return request.getSession();
}
@Override
public HttpSession getSession(boolean create) {
return request.getSession(create);
}
@Override
public long getTimeStamp() {
return System.currentTimeMillis();
}
@Override
public HttpURI getHttpURI() {
return new HttpURI(getRequestURI());
}
@Override
public UserIdentity getUserIdentity() {
throw new AssertionError();
}
@Override
public UserIdentity getResolvedUserIdentity() {
throw new AssertionError();
}
@Override
public UserIdentity.Scope getUserIdentityScope() {
throw new AssertionError();
}
@Override
public Principal getUserPrincipal() {
throw new AssertionError();
}
@Override
public boolean isHandled() {
throw new AssertionError();
}
@Override
public boolean isAsyncStarted() {
return request.isAsyncStarted();
}
@Override
public boolean isAsyncSupported() {
return request.isAsyncSupported();
}
@Override
public boolean isRequestedSessionIdFromCookie() {
return request.isRequestedSessionIdFromCookie();
}
@Override
public boolean isRequestedSessionIdFromUrl() {
return request.isRequestedSessionIdFromUrl();
}
@Override
public boolean isRequestedSessionIdFromURL() {
return request.isRequestedSessionIdFromURL();
}
@Override
public boolean isRequestedSessionIdValid() {
return request.isRequestedSessionIdValid();
}
@Override
public boolean isSecure() {
return request.isSecure();
}
@Override
public void setSecure(boolean secure) {
throw new AssertionError();
}
@Override
public boolean isUserInRole(String role) {
return request.isUserInRole(role);
}
@Override
public void removeAttribute(String name) {
request.removeAttribute(name);
}
@Override
public void removeEventListener(EventListener listener) {
throw new AssertionError();
}
@Override
public void setAsyncSupported(boolean supported, String source) {
throw new AssertionError();
}
@Override
public void setAttribute(String name, Object value) {
throw new AssertionError();
}
@Override
public void setAttributes(Attributes attributes) {
throw new AssertionError();
}
@Override
public void setAuthentication(Authentication authentication) {
throw new AssertionError();
}
@Override
public void setCharacterEncoding(String encoding) throws UnsupportedEncodingException {
throw new AssertionError();
}
@Override
public void setCharacterEncodingUnchecked(String encoding) {
throw new AssertionError();
}
@Override
public void setContentType(String contentType) {
throw new AssertionError();
}
@Override
public void setContext(ContextHandler.Context context) {
throw new AssertionError();
}
@Override
public boolean takeNewContext() {
throw new AssertionError();
}
@Override
public void setContextPath(String contextPath) {
throw new AssertionError();
}
@Override
public void setCookies(Cookie[] cookies) {
throw new AssertionError();
}
@Override
public void setDispatcherType(DispatcherType type) {
throw new AssertionError();
}
@Override
public void setHandled(boolean h) {
throw new AssertionError();
}
@Override
public boolean isHead() {
throw new AssertionError();
}
@Override
public void setPathInfo(String pathInfo) {
throw new AssertionError();
}
@Override
public void setHttpVersion(HttpVersion version) {
throw new AssertionError();
}
@Override
public void setQueryEncoding(String queryEncoding) {
throw new AssertionError();
}
@Override
public void setQueryString(String queryString) {
throw new AssertionError();
}
@Override
public void setRemoteAddr(InetSocketAddress addr) {
throw new AssertionError();
}
@Override
public void setRequestedSessionId(String requestedSessionId) {
throw new AssertionError();
}
@Override
public void setRequestedSessionIdFromCookie(boolean requestedSessionIdCookie) {
throw new AssertionError();
}
@Override
public void setScheme(String scheme) {
throw new AssertionError();
}
@Override
public void setServletPath(String servletPath) {
throw new AssertionError();
}
@Override
public void setSession(HttpSession session) {
throw new AssertionError();
}
@Override
public void setTimeStamp(long ts) {
throw new AssertionError();
}
@Override
public void setHttpURI(HttpURI uri) {
throw new AssertionError();
}
@Override
public void setUserIdentityScope(UserIdentity.Scope scope) {
throw new AssertionError();
}
@Override
public AsyncContext startAsync() throws IllegalStateException {
throw new AssertionError();
}
@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
throw new AssertionError();
}
@Override
public String toString() {
return request.toString();
}
@Override
public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
throw new AssertionError();
}
@Override
public Part getPart(String name) throws IOException, ServletException {
return request.getPart(name);
}
@Override
public Collection<Part> getParts() throws IOException, ServletException {
return request.getParts();
}
@Override
public void login(String username, String password) throws ServletException {
throw new AssertionError();
}
@Override
public void logout() throws ServletException {
throw new AssertionError();
}
}

View File

@@ -1,449 +0,0 @@
package org.whispersystems.websocket.servlet;
import org.eclipse.jetty.http.HttpContent;
import org.eclipse.jetty.http.HttpCookie;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.http.MetaData;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpOutput;
import org.eclipse.jetty.server.HttpTransport;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.util.Callback;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ReadPendingException;
import java.nio.channels.WritePendingException;
import java.util.Collection;
import java.util.Locale;
public class LoggableResponse extends Response {
private final HttpServletResponse response;
public LoggableResponse(HttpServletResponse response) {
super(null, null);
this.response = response;
}
@Override
public void putHeaders(HttpContent httpContent, long contentLength, boolean etag) {
throw new AssertionError();
}
@Override
public HttpOutput getHttpOutput() {
throw new AssertionError();
}
@Override
public boolean isIncluding() {
throw new AssertionError();
}
@Override
public void include() {
throw new AssertionError();
}
@Override
public void included() {
throw new AssertionError();
}
@Override
public void addCookie(HttpCookie cookie) {
throw new AssertionError();
}
@Override
public void addCookie(Cookie cookie) {
throw new AssertionError();
}
@Override
public boolean containsHeader(String name) {
return response.containsHeader(name);
}
@Override
public String encodeURL(String url) {
return response.encodeURL(url);
}
@Override
public String encodeRedirectURL(String url) {
return response.encodeRedirectURL(url);
}
@Override
public String encodeUrl(String url) {
return response.encodeUrl(url);
}
@Override
public String encodeRedirectUrl(String url) {
return response.encodeRedirectUrl(url);
}
@Override
public void sendError(int sc) throws IOException {
throw new AssertionError();
}
@Override
public void sendError(int code, String message) throws IOException {
throw new AssertionError();
}
@Override
public void sendProcessing() throws IOException {
throw new AssertionError();
}
@Override
public void sendRedirect(String location) throws IOException {
throw new AssertionError();
}
@Override
public void setDateHeader(String name, long date) {
throw new AssertionError();
}
@Override
public void addDateHeader(String name, long date) {
throw new AssertionError();
}
@Override
public void setHeader(HttpHeader name, String value) {
throw new AssertionError();
}
@Override
public void setHeader(String name, String value) {
throw new AssertionError();
}
@Override
public Collection<String> getHeaderNames() {
return response.getHeaderNames();
}
@Override
public String getHeader(String name) {
return response.getHeader(name);
}
@Override
public Collection<String> getHeaders(String name) {
return response.getHeaders(name);
}
@Override
public void addHeader(String name, String value) {
throw new AssertionError();
}
@Override
public void setIntHeader(String name, int value) {
throw new AssertionError();
}
@Override
public void addIntHeader(String name, int value) {
throw new AssertionError();
}
@Override
public void setStatus(int sc) {
throw new AssertionError();
}
@Override
public void setStatus(int sc, String sm) {
throw new AssertionError();
}
@Override
public void setStatusWithReason(int sc, String sm) {
throw new AssertionError();
}
@Override
public String getCharacterEncoding() {
return response.getCharacterEncoding();
}
@Override
public String getContentType() {
return response.getContentType();
}
@Override
public ServletOutputStream getOutputStream() throws IOException {
throw new AssertionError();
}
@Override
public boolean isWriting() {
throw new AssertionError();
}
@Override
public PrintWriter getWriter() throws IOException {
throw new AssertionError();
}
@Override
public void setContentLength(int len) {
throw new AssertionError();
}
@Override
public boolean isAllContentWritten(long written) {
throw new AssertionError();
}
@Override
public void closeOutput() throws IOException {
throw new AssertionError();
}
@Override
public long getLongContentLength() {
return response.getBufferSize();
}
@Override
public void setLongContentLength(long len) {
throw new AssertionError();
}
@Override
public void setCharacterEncoding(String encoding) {
throw new AssertionError();
}
@Override
public void setContentType(String contentType) {
throw new AssertionError();
}
@Override
public void setBufferSize(int size) {
throw new AssertionError();
}
@Override
public int getBufferSize() {
return response.getBufferSize();
}
@Override
public void flushBuffer() throws IOException {
throw new AssertionError();
}
@Override
public void reset() {
throw new AssertionError();
}
@Override
public void reset(boolean preserveCookies) {
throw new AssertionError();
}
@Override
public void resetForForward() {
throw new AssertionError();
}
@Override
public void resetBuffer() {
throw new AssertionError();
}
@Override
public boolean isCommitted() {
throw new AssertionError();
}
@Override
public void setLocale(Locale locale) {
throw new AssertionError();
}
@Override
public Locale getLocale() {
return response.getLocale();
}
@Override
public int getStatus() {
return response.getStatus();
}
@Override
public String getReason() {
throw new AssertionError();
}
@Override
public HttpFields getHttpFields() {
return new HttpFields();
}
@Override
public long getContentCount() {
return 0;
}
@Override
public String toString() {
return response.toString();
}
@Override
public MetaData.Response getCommittedMetaData() {
return new MetaData.Response(HttpVersion.HTTP_2, getStatus(), null);
}
@Override
public HttpChannel getHttpChannel()
{
return new HttpChannel(null, new HttpConfiguration(), new NullEndPoint(), null);
}
private static class NullEndPoint implements EndPoint {
@Override
public InetSocketAddress getLocalAddress() {
return null;
}
@Override
public InetSocketAddress getRemoteAddress() {
return null;
}
@Override
public boolean isOpen() {
return false;
}
@Override
public long getCreatedTimeStamp() {
return 0;
}
@Override
public void shutdownOutput() {
}
@Override
public boolean isOutputShutdown() {
return false;
}
@Override
public boolean isInputShutdown() {
return false;
}
@Override
public void close() {
}
@Override
public int fill(ByteBuffer buffer) throws IOException {
return 0;
}
@Override
public boolean flush(ByteBuffer... buffer) throws IOException {
return false;
}
@Override
public Object getTransport() {
return null;
}
@Override
public long getIdleTimeout() {
return 0;
}
@Override
public void setIdleTimeout(long idleTimeout) {
}
@Override
public void fillInterested(Callback callback) throws ReadPendingException {
}
@Override
public boolean tryFillInterested(Callback callback) {
return false;
}
@Override
public boolean isFillInterested() {
return false;
}
@Override
public void write(Callback callback, ByteBuffer... buffers) throws WritePendingException {
}
@Override
public Connection getConnection() {
return null;
}
@Override
public void setConnection(Connection connection) {
}
@Override
public void onOpen() {
}
@Override
public void onClose() {
}
@Override
public boolean isOptimizedForDirectBuffers() {
return false;
}
@Override
public void upgrade(Connection newConnection) {
}
}
}

View File

@@ -1,42 +0,0 @@
/**
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.websocket.servlet;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import java.io.IOException;
public class NullServletOutputStream extends ServletOutputStream {
@Override
public void write(int b) throws IOException {}
@Override
public void write(byte[] buf) {}
@Override
public void write(byte[] buf, int offset, int len) {}
@Override
public boolean isReady() {
return false;
}
@Override
public void setWriteListener(WriteListener writeListener) {
}
}

View File

@@ -1,171 +0,0 @@
/**
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.websocket.servlet;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Locale;
public class NullServletResponse implements HttpServletResponse {
@Override
public void addCookie(Cookie cookie) {}
@Override
public boolean containsHeader(String name) {
return false;
}
@Override
public String encodeURL(String url) {
return url;
}
@Override
public String encodeRedirectURL(String url) {
return url;
}
@Override
public String encodeUrl(String url) {
return url;
}
@Override
public String encodeRedirectUrl(String url) {
return url;
}
@Override
public void sendError(int sc, String msg) throws IOException {}
@Override
public void sendError(int sc) throws IOException {}
@Override
public void sendRedirect(String location) throws IOException {}
@Override
public void setDateHeader(String name, long date) {}
@Override
public void addDateHeader(String name, long date) {}
@Override
public void setHeader(String name, String value) {}
@Override
public void addHeader(String name, String value) {}
@Override
public void setIntHeader(String name, int value) {}
@Override
public void addIntHeader(String name, int value) {}
@Override
public void setStatus(int sc) {}
@Override
public void setStatus(int sc, String sm) {}
@Override
public int getStatus() {
return 200;
}
@Override
public String getHeader(String name) {
return null;
}
@Override
public Collection<String> getHeaders(String name) {
return new LinkedList<>();
}
@Override
public Collection<String> getHeaderNames() {
return new LinkedList<>();
}
@Override
public String getCharacterEncoding() {
return "UTF-8";
}
@Override
public String getContentType() {
return null;
}
@Override
public ServletOutputStream getOutputStream() throws IOException {
return new NullServletOutputStream();
}
@Override
public PrintWriter getWriter() throws IOException {
return new PrintWriter(new NullServletOutputStream());
}
@Override
public void setCharacterEncoding(String charset) {}
@Override
public void setContentLength(int len) {}
@Override
public void setContentLengthLong(long len) {}
@Override
public void setContentType(String type) {}
@Override
public void setBufferSize(int size) {}
@Override
public int getBufferSize() {
return 0;
}
@Override
public void flushBuffer() throws IOException {}
@Override
public void resetBuffer() {}
@Override
public boolean isCommitted() {
return true;
}
@Override
public void reset() {}
@Override
public void setLocale(Locale loc) {}
@Override
public Locale getLocale() {
return Locale.US;
}
}

View File

@@ -1,506 +0,0 @@
/**
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.websocket.servlet;
import org.whispersystems.websocket.messages.WebSocketRequestMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.servlet.AsyncContext;
import javax.servlet.DispatcherType;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpUpgradeHandler;
import javax.servlet.http.Part;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.security.Principal;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Locale;
import java.util.Map;
import java.util.Vector;
public class WebSocketServletRequest implements HttpServletRequest {
private final Map<String, String> headers = new HashMap<>();
private final Map<String, Object> attributes = new HashMap<>();
private final WebSocketRequestMessage requestMessage;
private final ServletInputStream inputStream;
private final ServletContext servletContext;
private final WebSocketSessionContext sessionContext;
public WebSocketServletRequest(WebSocketSessionContext sessionContext,
WebSocketRequestMessage requestMessage,
ServletContext servletContext)
{
this.requestMessage = requestMessage;
this.servletContext = servletContext;
this.sessionContext = sessionContext;
if (requestMessage.getBody().isPresent()) {
inputStream = new BufferingServletInputStream(requestMessage.getBody().get());
} else {
inputStream = new BufferingServletInputStream(new byte[0]);
}
headers.putAll(requestMessage.getHeaders());
}
@Override
public String getAuthType() {
return BASIC_AUTH;
}
@Override
public Cookie[] getCookies() {
return new Cookie[0];
}
@Override
public long getDateHeader(String name) {
return -1;
}
@Override
public String getHeader(String name) {
return headers.get(name.toLowerCase());
}
@Override
public Enumeration<String> getHeaders(String name) {
String header = this.headers.get(name.toLowerCase());
Vector<String> results = new Vector<>();
if (header != null) {
results.add(header);
}
return results.elements();
}
@Override
public Enumeration<String> getHeaderNames() {
return new Vector<>(headers.keySet()).elements();
}
@Override
public int getIntHeader(String name) {
return -1;
}
@Override
public String getMethod() {
return requestMessage.getVerb();
}
@Override
public String getPathInfo() {
return requestMessage.getPath();
}
@Override
public String getPathTranslated() {
return requestMessage.getPath();
}
@Override
public String getContextPath() {
return "";
}
@Override
public String getQueryString() {
if (requestMessage.getPath().contains("?")) {
return requestMessage.getPath().substring(requestMessage.getPath().indexOf("?") + 1);
}
return null;
}
@Override
public String getRemoteUser() {
return null;
}
@Override
public boolean isUserInRole(String role) {
return false;
}
@Override
public Principal getUserPrincipal() {
return new ContextPrincipal(sessionContext);
}
@Override
public String getRequestedSessionId() {
return null;
}
@Override
public String getRequestURI() {
if (requestMessage.getPath().contains("?")) {
return requestMessage.getPath().substring(0, requestMessage.getPath().indexOf("?"));
} else {
return requestMessage.getPath();
}
}
@Override
public StringBuffer getRequestURL() {
StringBuffer stringBuffer = new StringBuffer();
stringBuffer.append("http://websocket");
stringBuffer.append(getRequestURI());
return stringBuffer;
}
@Override
public String getServletPath() {
return "";
}
@Override
public HttpSession getSession(boolean create) {
return null;
}
@Override
public HttpSession getSession() {
return null;
}
@Override
public String changeSessionId() {
return null;
}
@Override
public boolean isRequestedSessionIdValid() {
return false;
}
@Override
public boolean isRequestedSessionIdFromCookie() {
return false;
}
@Override
public boolean isRequestedSessionIdFromURL() {
return false;
}
@Override
public boolean isRequestedSessionIdFromUrl() {
return false;
}
@Override
public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
return false;
}
@Override
public void login(String username, String password) throws ServletException {
}
@Override
public void logout() throws ServletException {
}
@Override
public Collection<Part> getParts() throws IOException, ServletException {
return new LinkedList<>();
}
@Override
public Part getPart(String name) throws IOException, ServletException {
return null;
}
@Override
public <T extends HttpUpgradeHandler> T upgrade(Class<T> handlerClass) throws IOException, ServletException {
return null;
}
@Override
public Object getAttribute(String name) {
return attributes.get(name);
}
@Override
public Enumeration<String> getAttributeNames() {
return new Vector<>(attributes.keySet()).elements();
}
@Override
public String getCharacterEncoding() {
return null;
}
@Override
public void setCharacterEncoding(String env) throws UnsupportedEncodingException {}
@Override
public int getContentLength() {
if (requestMessage.getBody().isPresent()) {
return requestMessage.getBody().get().length;
} else {
return 0;
}
}
@Override
public long getContentLengthLong() {
return getContentLength();
}
@Override
public String getContentType() {
if (requestMessage.getBody().isPresent()) {
return "application/json";
} else {
return null;
}
}
@Override
public ServletInputStream getInputStream() throws IOException {
return inputStream;
}
@Override
public String getParameter(String name) {
String[] result = getParameterMap().get(name);
if (result != null && result.length > 0) {
return result[0];
}
return null;
}
@Override
public Enumeration<String> getParameterNames() {
return new Vector<>(getParameterMap().keySet()).elements();
}
@Override
public String[] getParameterValues(String name) {
return getParameterMap().get(name);
}
@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> parameterMap = new HashMap<>();
String queryParameters = getQueryString();
if (queryParameters == null) {
return parameterMap;
}
String[] tokens = queryParameters.split("&");
for (String token : tokens) {
String[] parts = token.split("=");
if (parts != null && parts.length > 1) {
parameterMap.put(parts[0], new String[] {parts[1]});
}
}
return parameterMap;
}
@Override
public String getProtocol() {
return "HTTP/1.0";
}
@Override
public String getScheme() {
return "http";
}
@Override
public String getServerName() {
return "websocket";
}
@Override
public int getServerPort() {
return 8080;
}
@Override
public BufferedReader getReader() throws IOException {
return new BufferedReader(new InputStreamReader(inputStream));
}
@Override
public String getRemoteAddr() {
return "127.0.0.1";
}
@Override
public String getRemoteHost() {
return "localhost";
}
@Override
public void setAttribute(String name, Object o) {
if (o != null) attributes.put(name, o);
else removeAttribute(name);
}
@Override
public void removeAttribute(String name) {
attributes.remove(name);
}
@Override
public Locale getLocale() {
return Locale.US;
}
@Override
public Enumeration<Locale> getLocales() {
Vector<Locale> results = new Vector<>();
results.add(getLocale());
return results.elements();
}
@Override
public boolean isSecure() {
return false;
}
@Override
public RequestDispatcher getRequestDispatcher(String path) {
return servletContext.getRequestDispatcher(path);
}
@Override
public String getRealPath(String path) {
return path;
}
@Override
public int getRemotePort() {
return 31337;
}
@Override
public String getLocalName() {
return "localhost";
}
@Override
public String getLocalAddr() {
return "127.0.0.1";
}
@Override
public int getLocalPort() {
return 8080;
}
@Override
public ServletContext getServletContext() {
return servletContext;
}
@Override
public AsyncContext startAsync() throws IllegalStateException {
throw new AssertionError("nyi");
}
@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
throw new AssertionError("nyi");
}
@Override
public boolean isAsyncStarted() {
return false;
}
@Override
public boolean isAsyncSupported() {
return false;
}
@Override
public AsyncContext getAsyncContext() {
return null;
}
@Override
public DispatcherType getDispatcherType() {
return DispatcherType.REQUEST;
}
public static class ContextPrincipal implements Principal {
private final WebSocketSessionContext context;
public ContextPrincipal(WebSocketSessionContext context) {
this.context = context;
}
@Override
public boolean equals(Object another) {
return another instanceof ContextPrincipal &&
context.equals(((ContextPrincipal) another).context);
}
@Override
public String toString() {
return super.toString();
}
@Override
public int hashCode() {
return context.hashCode();
}
@Override
public String getName() {
return "WebSocketSessionContext";
}
public WebSocketSessionContext getContext() {
return context;
}
}
}

View File

@@ -1,270 +0,0 @@
/**
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.websocket.servlet;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Locale;
import java.util.Optional;
public class WebSocketServletResponse implements HttpServletResponse {
@SuppressWarnings("unused")
private static final Logger logger = LoggerFactory.getLogger(WebSocketServletResponse.class);
private final RemoteEndpoint endPoint;
private final long requestId;
private final WebSocketMessageFactory messageFactory;
private ResponseBuilder responseBuilder = new ResponseBuilder();
private ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
private ServletOutputStream servletOutputStream = new BufferingServletOutputStream(responseBody);
private boolean isCommitted = false;
public WebSocketServletResponse(RemoteEndpoint endPoint, long requestId,
WebSocketMessageFactory messageFactory)
{
this.endPoint = endPoint;
this.requestId = requestId;
this.messageFactory = messageFactory;
this.responseBuilder.setRequestId(requestId);
}
@Override
public void addCookie(Cookie cookie) {}
@Override
public boolean containsHeader(String name) {
return false;
}
@Override
public String encodeURL(String url) {
return url;
}
@Override
public String encodeRedirectURL(String url) {
return url;
}
@Override
public String encodeUrl(String url) {
return url;
}
@Override
public String encodeRedirectUrl(String url) {
return url;
}
@Override
public void sendError(int sc, String msg) throws IOException {
setStatus(sc, msg);
}
@Override
public void sendError(int sc) throws IOException {
setStatus(sc);
}
@Override
public void sendRedirect(String location) throws IOException {
throw new IOException("Not supported!");
}
@Override
public void setDateHeader(String name, long date) {}
@Override
public void addDateHeader(String name, long date) {}
@Override
public void setHeader(String name, String value) {}
@Override
public void addHeader(String name, String value) {}
@Override
public void setIntHeader(String name, int value) {}
@Override
public void addIntHeader(String name, int value) {}
@Override
public void setStatus(int sc) {
setStatus(sc, "");
}
@Override
public void setStatus(int sc, String sm) {
this.responseBuilder.setStatusCode(sc);
this.responseBuilder.setMessage(sm);
}
@Override
public int getStatus() {
return this.responseBuilder.getStatusCode();
}
@Override
public String getHeader(String name) {
return null;
}
@Override
public Collection<String> getHeaders(String name) {
return new LinkedList<>();
}
@Override
public Collection<String> getHeaderNames() {
return new LinkedList<>();
}
@Override
public String getCharacterEncoding() {
return "UTF-8";
}
@Override
public String getContentType() {
return null;
}
@Override
public ServletOutputStream getOutputStream() throws IOException {
return servletOutputStream;
}
@Override
public PrintWriter getWriter() throws IOException {
return new PrintWriter(servletOutputStream);
}
@Override
public void setCharacterEncoding(String charset) {}
@Override
public void setContentLength(int len) {}
@Override
public void setContentLengthLong(long len) {}
@Override
public void setContentType(String type) {}
@Override
public void setBufferSize(int size) {}
@Override
public int getBufferSize() {
return 0;
}
@Override
public void flushBuffer() throws IOException {
if (!isCommitted) {
byte[] body = responseBody.toByteArray();
if (body.length <= 0) {
body = null;
}
byte[] response = messageFactory.createResponse(responseBuilder.getRequestId(),
responseBuilder.getStatusCode(),
responseBuilder.getMessage(),
new LinkedList<>(),
Optional.ofNullable(body))
.toByteArray();
endPoint.sendBytesByFuture(ByteBuffer.wrap(response));
isCommitted = true;
}
}
@Override
public void resetBuffer() {
if (isCommitted) throw new IllegalStateException("Buffer already flushed!");
responseBody.reset();
}
@Override
public boolean isCommitted() {
return isCommitted;
}
@Override
public void reset() {
if (isCommitted) throw new IllegalStateException("Buffer already flushed!");
responseBuilder = new ResponseBuilder();
responseBuilder.setRequestId(requestId);
resetBuffer();
}
@Override
public void setLocale(Locale loc) {}
@Override
public Locale getLocale() {
return Locale.US;
}
private static class ResponseBuilder {
private long requestId;
private int statusCode;
private String message = "";
public long getRequestId() {
return requestId;
}
public void setRequestId(long requestId) {
this.requestId = requestId;
}
public int getStatusCode() {
return statusCode;
}
public void setStatusCode(int statusCode) {
this.statusCode = statusCode;
}
public String getMessage() {
return message;
}
public void setMessage(String message) {
this.message = message;
}
}
}

View File

@@ -0,0 +1,37 @@
package org.whispersystems.websocket.session;
import java.security.Principal;
public class ContextPrincipal implements Principal {
private final WebSocketSessionContext context;
public ContextPrincipal(WebSocketSessionContext context) {
this.context = context;
}
@Override
public boolean equals(Object another) {
return another instanceof ContextPrincipal &&
context.equals(((ContextPrincipal) another).context);
}
@Override
public String toString() {
return super.toString();
}
@Override
public int hashCode() {
return context.hashCode();
}
@Override
public String getName() {
return "WebSocketSessionContext";
}
public WebSocketSessionContext getContext() {
return context;
}
}

View File

@@ -0,0 +1,31 @@
package org.whispersystems.websocket.session;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.websocket.WebSocketSecurityContext;
import javax.ws.rs.core.SecurityContext;
public class WebSocketSessionContainerRequestValueFactory {
private final ContainerRequest request;
public WebSocketSessionContainerRequestValueFactory(ContainerRequest request) {
this.request = request;
}
public WebSocketSessionContext provide() {
SecurityContext securityContext = request.getSecurityContext();
if (!(securityContext instanceof WebSocketSecurityContext)) {
throw new IllegalStateException("Security context isn't for websocket!");
}
WebSocketSessionContext sessionContext = ((WebSocketSecurityContext)securityContext).getSessionContext();
if (sessionContext == null) {
throw new IllegalStateException("No session context found for websocket!");
}
return sessionContext;
}
}

View File

@@ -1,73 +1,45 @@
package org.whispersystems.websocket.session;
import org.glassfish.hk2.api.InjectionResolver;
import org.glassfish.hk2.api.ServiceLocator;
import org.glassfish.hk2.api.TypeLiteral;
import org.glassfish.hk2.utilities.binding.AbstractBinder;
import org.glassfish.jersey.server.internal.inject.AbstractContainerRequestValueFactory;
import org.glassfish.jersey.server.internal.inject.AbstractValueFactoryProvider;
import org.glassfish.jersey.internal.inject.AbstractBinder;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider;
import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider;
import org.glassfish.jersey.server.internal.inject.ParamInjectionResolver;
import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueFactoryProvider;
import org.whispersystems.websocket.servlet.WebSocketServletRequest;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Singleton;
import java.security.Principal;
import java.util.function.Function;
@Singleton
public class WebSocketSessionContextValueFactoryProvider extends AbstractValueFactoryProvider {
public class WebSocketSessionContextValueFactoryProvider extends AbstractValueParamProvider {
@Inject
public WebSocketSessionContextValueFactoryProvider(MultivaluedParameterExtractorProvider mpep,
ServiceLocator injector)
{
super(mpep, injector, Parameter.Source.UNKNOWN);
public WebSocketSessionContextValueFactoryProvider(MultivaluedParameterExtractorProvider mpep) {
super(() -> mpep, Parameter.Source.UNKNOWN);
}
@Nullable
@Override
public AbstractContainerRequestValueFactory<WebSocketSessionContext> createValueFactory(Parameter parameter) {
if (parameter.getAnnotation(WebSocketSession.class) == null) {
protected Function<ContainerRequest, ?> createValueProvider(Parameter parameter) {
if (!parameter.isAnnotationPresent(WebSocketSession.class)) {
return null;
}
return new AbstractContainerRequestValueFactory<WebSocketSessionContext>() {
public WebSocketSessionContext provide() {
Principal principal = getContainerRequest().getSecurityContext().getUserPrincipal();
if (principal == null) {
throw new IllegalStateException("Cannot inject a custom principal into unauthenticated request");
}
if (!(principal instanceof WebSocketServletRequest.ContextPrincipal)) {
throw new IllegalArgumentException("Cannot inject a non-WebSocket AuthPrincipal into request");
}
return ((WebSocketServletRequest.ContextPrincipal)principal).getContext();
}
};
}
@Singleton
private static class WebSocketSessionInjectionResolver extends ParamInjectionResolver<WebSocketSession> {
public WebSocketSessionInjectionResolver() {
super(WebSocketSessionContextValueFactoryProvider.class);
} else if (WebSocketSessionContext.class.equals(parameter.getRawType())) {
return request -> new WebSocketSessionContainerRequestValueFactory(request).provide();
} else {
throw new IllegalArgumentException("Can't inject custom type");
}
}
public static class Binder extends AbstractBinder {
public Binder() {
}
public Binder() { }
@Override
protected void configure() {
bind(WebSocketSessionContextValueFactoryProvider.class).to(ValueFactoryProvider.class).in(Singleton.class);
bind(WebSocketSessionInjectionResolver.class).to(new TypeLiteral<InjectionResolver<WebSocketSession>>() {
}).in(Singleton.class);
bind(WebSocketSessionContextValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class);
}
}
}

View File

@@ -17,33 +17,30 @@
package org.whispersystems.websocket.setup;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.eclipse.jetty.server.RequestLog;
import org.glassfish.jersey.servlet.ServletContainer;
import org.glassfish.jersey.server.ResourceConfig;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import javax.servlet.http.HttpServlet;
import javax.validation.Validator;
import java.security.Principal;
import io.dropwizard.jersey.DropwizardResourceConfig;
import io.dropwizard.jersey.setup.JerseyContainerHolder;
import io.dropwizard.jersey.setup.JerseyEnvironment;
import io.dropwizard.setup.Environment;
public class WebSocketEnvironment {
public class WebSocketEnvironment<T extends Principal> {
private final JerseyContainerHolder jerseyServletContainer;
private final JerseyEnvironment jerseyEnvironment;
private final ResourceConfig jerseyConfig;
private final ObjectMapper objectMapper;
private final Validator validator;
private final RequestLog requestLog;
private final WebsocketRequestLog requestLog;
private final long idleTimeoutMillis;
private WebSocketAuthenticator authenticator;
private WebSocketMessageFactory messageFactory;
private WebSocketConnectListener connectListener;
private WebSocketAuthenticator<T> authenticator;
private WebSocketMessageFactory messageFactory;
private WebSocketConnectListener connectListener;
public WebSocketEnvironment(Environment environment, WebSocketConfiguration configuration) {
this(environment, configuration, 60000);
@@ -53,27 +50,24 @@ public class WebSocketEnvironment {
this(environment, configuration.getRequestLog().build("websocket"), idleTimeoutMillis);
}
public WebSocketEnvironment(Environment environment, RequestLog requestLog, long idleTimeoutMillis) {
DropwizardResourceConfig jerseyConfig = new DropwizardResourceConfig(environment.metrics());
this.objectMapper = environment.getObjectMapper();
this.validator = environment.getValidator();
this.requestLog = requestLog;
this.jerseyServletContainer = new JerseyContainerHolder(new ServletContainer(jerseyConfig) );
this.jerseyEnvironment = new JerseyEnvironment(jerseyServletContainer, jerseyConfig);
this.messageFactory = new ProtobufWebSocketMessageFactory();
this.idleTimeoutMillis = idleTimeoutMillis;
public WebSocketEnvironment(Environment environment, WebsocketRequestLog requestLog, long idleTimeoutMillis) {
this.jerseyConfig = new DropwizardResourceConfig(environment.metrics());
this.objectMapper = environment.getObjectMapper();
this.validator = environment.getValidator();
this.requestLog = requestLog;
this.messageFactory = new ProtobufWebSocketMessageFactory();
this.idleTimeoutMillis = idleTimeoutMillis;
}
public JerseyEnvironment jersey() {
return jerseyEnvironment;
public ResourceConfig jersey() {
return jerseyConfig;
}
public WebSocketAuthenticator getAuthenticator() {
public WebSocketAuthenticator<T> getAuthenticator() {
return authenticator;
}
public void setAuthenticator(WebSocketAuthenticator authenticator) {
public void setAuthenticator(WebSocketAuthenticator<T> authenticator) {
this.authenticator = authenticator;
}
@@ -85,7 +79,7 @@ public class WebSocketEnvironment {
return objectMapper;
}
public RequestLog getRequestLog() {
public WebsocketRequestLog getRequestLog() {
return requestLog;
}
@@ -93,10 +87,6 @@ public class WebSocketEnvironment {
return validator;
}
public HttpServlet getJerseyServletContainer() {
return (HttpServlet)jerseyServletContainer.getContainer();
}
public WebSocketMessageFactory getMessageFactory() {
return messageFactory;
}