Update to dropwizard 2.x

This commit is contained in:
Moxie Marlinspike
2020-03-06 17:39:31 -08:00
parent 69285f28ad
commit 009f81a9a6
45 changed files with 1782 additions and 3011 deletions

View File

@@ -1,64 +0,0 @@
package org.whispersystems.websocket;
import org.eclipse.jetty.server.AbstractNCSARequestLog;
import org.eclipse.jetty.server.NCSARequestLog;
import org.eclipse.jetty.server.RequestLog;
import org.eclipse.jetty.util.component.AbstractLifeCycle;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.junit.Test;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketRequestMessage;
import org.whispersystems.websocket.servlet.LoggableRequest;
import org.whispersystems.websocket.servlet.LoggableResponse;
import org.whispersystems.websocket.servlet.WebSocketServletRequest;
import org.whispersystems.websocket.servlet.WebSocketServletResponse;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.HashMap;
import java.util.Optional;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class LoggableRequestResponseTest {
@Test
public void testLogging() {
NCSARequestLog requestLog = new EnabledNCSARequestLog();
WebSocketClient webSocketClient = mock(WebSocketClient.class );
WebSocketRequestMessage requestMessage = mock(WebSocketRequestMessage.class);
ServletContext servletContext = mock(ServletContext.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class );
WebSocketMessageFactory messageFactory = mock(WebSocketMessageFactory.class);
when(requestMessage.getVerb()).thenReturn("GET");
when(requestMessage.getBody()).thenReturn(Optional.empty());
when(requestMessage.getHeaders()).thenReturn(new HashMap<>());
when(requestMessage.getPath()).thenReturn("/api/v1/test");
when(requestMessage.getRequestId()).thenReturn(1L);
when(requestMessage.hasRequestId()).thenReturn(true);
WebSocketSessionContext sessionContext = new WebSocketSessionContext (webSocketClient );
HttpServletRequest servletRequest = new WebSocketServletRequest (sessionContext, requestMessage, servletContext);
HttpServletResponse servletResponse = new WebSocketServletResponse(remoteEndpoint, 1, messageFactory );
LoggableRequest loggableRequest = new LoggableRequest (servletRequest );
LoggableResponse loggableResponse = new LoggableResponse(servletResponse);
requestLog.log(loggableRequest, loggableResponse);
}
private class EnabledNCSARequestLog extends NCSARequestLog {
@Override
public boolean isEnabled() {
return true;
}
}
}

View File

@@ -4,16 +4,21 @@ package org.whispersystems.websocket;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.glassfish.jersey.server.ResourceConfig;
import org.junit.Test;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
import javax.security.auth.Subject;
import javax.servlet.ServletException;
import java.io.IOException;
import java.security.Principal;
import java.util.Optional;
import io.dropwizard.jersey.setup.JerseyEnvironment;
import io.dropwizard.jersey.DropwizardResourceConfig;
import static org.junit.Assert.*;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
@@ -21,8 +26,8 @@ import static org.mockito.Mockito.*;
public class WebSocketResourceProviderFactoryTest {
@Test
public void testUnauthorized() throws ServletException, AuthenticationException, IOException {
JerseyEnvironment jerseyEnvironment = mock(JerseyEnvironment.class );
public void testUnauthorized() throws AuthenticationException, IOException {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class);
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
@@ -32,7 +37,7 @@ public class WebSocketResourceProviderFactoryTest {
when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true));
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
@@ -42,19 +47,19 @@ public class WebSocketResourceProviderFactoryTest {
@Test
public void testValidAuthorization() throws AuthenticationException, ServletException {
JerseyEnvironment jerseyEnvironment = mock(JerseyEnvironment.class );
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class);
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class);
Session session = mock(Session.class );
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class );
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class );
Session session = mock(Session.class );
Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
Object connection = factory.createWebSocket(request, response);
assertNotNull(connection);
@@ -67,7 +72,51 @@ public class WebSocketResourceProviderFactoryTest {
assertEquals(((WebSocketResourceProvider)connection).getContext().getAuthenticated(), account);
}
private static class Account {}
@Test
public void testErrorAuthorization() throws AuthenticationException, ServletException, IOException {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class );
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class );
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure"));
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
verify(response).sendError(eq(500), eq("Failure"));
verify(authenticator).authenticate(eq(request));
}
@Test
public void testConfigure() {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class );
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
factory.configure(servletFactory);
verify(servletFactory).setCreator(eq(factory));
}
private static class Account implements Principal {
@Override
public String getName() {
return null;
}
@Override
public boolean implies(Subject subject) {
return false;
}
}
}

View File

@@ -1,60 +1,95 @@
package org.whispersystems.websocket;
import org.eclipse.jetty.server.RequestLog;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import org.eclipse.jetty.websocket.api.CloseStatus;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import org.glassfish.jersey.server.ResourceConfig;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.mockito.stubbing.Answer;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import org.whispersystems.websocket.messages.protobuf.SubProtocol;
import org.whispersystems.websocket.session.WebSocketSession;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import javax.validation.Valid;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.jersey.DropwizardResourceConfig;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
public class WebSocketResourceProviderTest {
@Test
public void testOnConnect() throws AuthenticationException, IOException {
HttpServlet contextHandler = mock(HttpServlet.class);
WebSocketAuthenticator<String> authenticator = mock(WebSocketAuthenticator.class);
RequestLog requestLog = mock(RequestLog.class);
WebSocketResourceProvider provider = new WebSocketResourceProvider(contextHandler, requestLog,
null,
new ProtobufWebSocketMessageFactory(),
Optional.empty(),
30000);
public void testOnConnect() {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class );
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class );
WebSocketConnectListener connectListener = mock(WebSocketConnectListener.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
applicationHandler, requestLog,
new TestPrincipal("fooz"),
new ProtobufWebSocketMessageFactory(),
Optional.of(connectListener),
30000 );
Session session = mock(Session.class );
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(authenticator.authenticate(request)).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of("fooz"), true));
provider.onWebSocketConnect(session);
verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
ArgumentCaptor<WebSocketSessionContext> contextArgumentCaptor = ArgumentCaptor.forClass(WebSocketSessionContext.class);
verify(connectListener).onWebSocketConnect(contextArgumentCaptor.capture());
assertThat(contextArgumentCaptor.getValue().getAuthenticated(TestPrincipal.class).getName()).isEqualTo("fooz");
}
@Test
public void testRouteMessage() throws Exception {
HttpServlet servlet = mock(HttpServlet.class );
WebSocketAuthenticator<String> authenticator = mock(WebSocketAuthenticator.class);
RequestLog requestLog = mock(RequestLog.class );
WebSocketResourceProvider provider = new WebSocketResourceProvider(servlet, requestLog, Optional.of((WebSocketAuthenticator)authenticator), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
public void testMockedRouteMessageSuccess() throws Exception {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class );
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -62,7 +97,33 @@ public class WebSocketResourceProviderTest {
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(authenticator.authenticate(request)).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of("foo"), true));
ContainerResponse response = mock(ContainerResponse.class);
when(response.getStatus()).thenReturn(200);
when(response.getStatusInfo()).thenReturn(new Response.StatusType() {
@Override
public int getStatusCode() {
return 200;
}
@Override
public Response.Status.Family getFamily() {
return Response.Status.Family.SUCCESSFUL;
}
@Override
public String getReasonPhrase() {
return "OK";
}
});
ArgumentCaptor<OutputStream> responseOutputStream = ArgumentCaptor.forClass(OutputStream.class);
when(applicationHandler.apply(any(ContainerRequest.class), responseOutputStream.capture()))
.thenAnswer((Answer<CompletableFuture<ContainerResponse>>) invocation -> {
responseOutputStream.getValue().write("hello world!".getBytes());
return CompletableFuture.completedFuture(response);
});
provider.onWebSocketConnect(session);
@@ -70,21 +131,567 @@ public class WebSocketResourceProviderTest {
verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<String>(), Optional.of("hello world!".getBytes())).toByteArray();
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<HttpServletRequest> requestCaptor = ArgumentCaptor.forClass(HttpServletRequest.class);
ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class);
ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class );
verify(servlet).service(requestCaptor.capture(), any(HttpServletResponse.class));
verify(applicationHandler).apply(requestCaptor.capture(), any(OutputStream.class));
HttpServletRequest bundledRequest = requestCaptor.getValue();
ContainerRequest bundledRequest = requestCaptor.getValue();
byte[] expected = new byte[bundledRequest.getInputStream().available()];
int read = bundledRequest.getInputStream().read(expected);
assertThat(bundledRequest.getRequest().getMethod()).isEqualTo("GET");
assertThat(bundledRequest.getBaseUri().toString()).isEqualTo("/");
assertThat(bundledRequest.getPath(false)).isEqualTo("bar");
assertThat(read).isEqualTo(expected.length);
assertThat(new String(expected)).isEqualTo("hello world!");
verify(requestLog).log(eq("127.0.0.1"), eq(bundledRequest), eq(response));
verify(remoteEndpoint).sendBytesByFuture(responseCaptor.capture());
SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array());
assertThat(responseMessageContainer.getResponse().getId()).isEqualTo(111L);
assertThat(responseMessageContainer.getResponse().getStatus()).isEqualTo(200);
assertThat(responseMessageContainer.getResponse().getMessage()).isEqualTo("OK");
assertThat(responseMessageContainer.getResponse().getBody()).isEqualTo(ByteString.copyFrom("hello world!".getBytes()));
}
@Test
public void testMockedRouteMessageFailure() throws Exception {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class );
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(applicationHandler.apply(any(ContainerRequest.class), any(OutputStream.class))).thenReturn(CompletableFuture.failedFuture(new IllegalStateException("foo")));
provider.onWebSocketConnect(session);
verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class);
verify(applicationHandler).apply(requestCaptor.capture(), any(OutputStream.class));
ContainerRequest bundledRequest = requestCaptor.getValue();
assertThat(bundledRequest.getRequest().getMethod()).isEqualTo("GET");
assertThat(bundledRequest.getBaseUri().toString()).isEqualTo("/");
assertThat(bundledRequest.getPath(false)).isEqualTo("bar");
ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseCaptor.capture());
SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array());
assertThat(responseMessageContainer.getResponse().getStatus()).isEqualTo(500);
assertThat(responseMessageContainer.getResponse().getMessage()).isEqualTo("Error response");
assertThat(responseMessageContainer.getResponse().hasBody()).isFalse();
}
@Test
public void testActualRouteMessageSuccess() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getMessage()).isEqualTo("OK");
assertThat(response.getBody()).isEqualTo(ByteString.copyFrom("Hello!".getBytes()));
}
@Test
public void testActualRouteMessageNotFound() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/doesntexist", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(404);
assertThat(response.getMessage()).isEqualTo("Not Found");
assertThat(response.hasBody()).isFalse();
}
@Test
public void testActualRouteMessageAuthorized() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("authorizedUserName"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getMessage()).isEqualTo("OK");
assertThat(response.getBody().toStringUtf8()).isEqualTo("World: authorizedUserName");
}
@Test
public void testActualRouteMessageUnauthorized() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(401);
assertThat(response.hasBody()).isFalse();
}
@Test
public void testActualRouteMessageOptionalAuthorizedPresent() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("something"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getMessage()).isEqualTo("OK");
assertThat(response.getBody().toStringUtf8()).isEqualTo("World: something");
}
@Test
public void testActualRouteMessageOptionalAuthorizedEmpty() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getMessage()).isEqualTo("OK");
assertThat(response.getBody().toStringUtf8()).isEqualTo("Empty world");
}
@Test
public void testActualRouteMessagePutAuthenticatedEntity() throws InvalidProtocolBufferException, JsonProcessingException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 1001)))).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getMessage()).isEqualTo("OK");
assertThat(response.getBody().toStringUtf8()).isEqualTo("gooduser:testparam:mykey:1001");
}
@Test
public void testActualRouteMessagePutAuthenticatedBadEntity() throws InvalidProtocolBufferException, JsonProcessingException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 5)))).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.getMessage()).isEqualTo("Bad Request");
assertThat(response.hasBody()).isFalse();
}
@Test
public void testActualRouteMessageExceptionMapping() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new TestExceptionMapper());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/exception/map", List.of("Content-Type: application/json"), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(1337);
assertThat(response.hasBody()).isFalse();
}
@Test
public void testActualRouteSessionContextInjection() throws InvalidProtocolBufferException {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(new TestResource());
resourceConfig.register(new TestExceptionMapper());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000);
Session session = mock(Session.class );
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/keepalive", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> requestCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(requestCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketRequestMessage requestMessage = getRequest(requestCaptor);
assertThat(requestMessage.getVerb()).isEqualTo("GET");
assertThat(requestMessage.getPath()).isEqualTo("/v1/miccheck");
assertThat(requestMessage.getBody().toStringUtf8()).isEqualTo("smert ze smert");
byte[] clientResponse = new ProtobufWebSocketMessageFactory().createResponse(requestMessage.getId(), 200, "OK", new LinkedList<>(), Optional.of("my response".getBytes())).toByteArray();
provider.onWebSocketBinary(clientResponse, 0, clientResponse.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
assertThat(response.getId()).isEqualTo(111L);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getMessage()).isEqualTo("OK");
assertThat(response.getBody().toStringUtf8()).isEqualTo("my response");
}
private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor<ByteBuffer> responseCaptor) throws InvalidProtocolBufferException {
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
}
private SubProtocol.WebSocketRequestMessage getRequest(ArgumentCaptor<ByteBuffer> requestCaptor) throws InvalidProtocolBufferException {
return SubProtocol.WebSocketMessage.parseFrom(requestCaptor.getValue().array()).getRequest();
}
public static class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
}
public static class TestException extends Exception {
public TestException(String message) {
super(message);
}
}
@Provider
public static class TestExceptionMapper implements ExceptionMapper<TestException> {
@Override
public Response toResponse(TestException exception) {
return Response.status(1337).build();
}
}
@Path("/v1/test")
public static class TestResource {
@GET
@Path("/hello")
public String testGetHello() {
return "Hello!";
}
@GET
@Path("/world")
public String testAuthorizedHello(@Auth TestPrincipal user) {
if (user == null) throw new AssertionError();
return "World: " + user.getName();
}
@GET
@Path("/optional")
public String testAuthorizedHello(@Auth Optional<TestPrincipal> user) {
if (user.isPresent()) return "World: " + user.get().getName();
else return "Empty world";
}
@PUT
@Path("/some/{param}")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public Response testSet(@Auth TestPrincipal user, @PathParam ("param") String param, @Valid TestEntity entity) {
return Response.ok(user.name + ":" + param + ":" + entity.key + ":" + entity.value).build();
}
@GET
@Path("/exception/map")
public Response testExceptionMapping() throws TestException {
throw new TestException("I'd like to map this");
}
@GET
@Path("/keepalive")
public CompletableFuture<Response> testContextInjection(@WebSocketSession WebSocketSessionContext context) {
if (context == null) {
throw new AssertionError();
}
return context.getClient()
.sendRequest("GET", "/v1/miccheck", new LinkedList<>(), Optional.of("smert ze smert".getBytes()))
.thenApply(response -> Response.ok().entity(new String(response.getBody().get())).build());
}
public static class TestEntity {
public TestEntity(String key, long value) {
this.key = key;
this.value = value;
}
public TestEntity() {
}
@JsonProperty
@NotEmpty
private String key;
@JsonProperty
@Min(100)
private long value;
}
}
}

View File

@@ -0,0 +1,120 @@
package org.whispersystems.websocket.logging;
import org.glassfish.jersey.internal.MapPropertiesDelegate;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import org.junit.Test;
import org.whispersystems.websocket.WebSocketSecurityContext;
import org.whispersystems.websocket.session.ContextPrincipal;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.ws.rs.core.Response;
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import ch.qos.logback.classic.LoggerContext;
import ch.qos.logback.core.OutputStreamAppender;
import ch.qos.logback.core.spi.DeferredProcessingAware;
import io.dropwizard.logging.AbstractOutputStreamAppenderFactory;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
public class WebSocketRequestLogTest {
@Test
public void testLogLineWithoutHeaders() throws InterruptedException {
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
ListAppender<WebsocketEvent> listAppender = new ListAppender<>();
WebsocketRequestLoggerFactory requestLoggerFactory = new WebsocketRequestLoggerFactory();
requestLoggerFactory.appenders = List.of(new ListAppenderFactory<>(listAppender));
WebsocketRequestLog requestLog = requestLoggerFactory.build("test-logger");
ContainerRequest request = new ContainerRequest (null, URI.create("/v1/test"), "GET", new WebSocketSecurityContext(new ContextPrincipal(sessionContext)), new MapPropertiesDelegate(new HashMap<>()), null);
ContainerResponse response = new ContainerResponse(request, Response.ok("My response body").build());
requestLog.log("123.456.789.123", request, response);
listAppender.waitForListSize(1);
assertThat(listAppender.list.size()).isEqualTo(1);
String loggedLine = new String(listAppender.outputStream.toByteArray());
assertThat(loggedLine.matches("123\\.456\\.789\\.123 \\- \\- \\[[0-9]{2}\\/[a-zA-Z]{3}\\/[0-9]{4}:[0-9]{2}:[0-9]{2}:[0-9]{2} \\-[0-9]{4}\\] \"GET \\/v1\\/test WS\" 200 \\- \"\\-\" \"\\-\"\n")).isTrue();
}
@Test
public void testLogLineWithHeaders() throws InterruptedException {
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
ListAppender<WebsocketEvent> listAppender = new ListAppender<>();
WebsocketRequestLoggerFactory requestLoggerFactory = new WebsocketRequestLoggerFactory();
requestLoggerFactory.appenders = List.of(new ListAppenderFactory<>(listAppender));
WebsocketRequestLog requestLog = requestLoggerFactory.build("test-logger");
ContainerRequest request = new ContainerRequest (null, URI.create("/v1/test"), "GET", new WebSocketSecurityContext(new ContextPrincipal(sessionContext)), new MapPropertiesDelegate(new HashMap<>()), null);
request.header("User-Agent", "SmertZeSmert");
request.header("Referer", "https://moxie.org");
ContainerResponse response = new ContainerResponse(request, Response.ok("My response body").build());
requestLog.log("123.456.789.123", request, response);
listAppender.waitForListSize(1);
assertThat(listAppender.list.size()).isEqualTo(1);
String loggedLine = new String(listAppender.outputStream.toByteArray());
assertThat(loggedLine.matches("123\\.456\\.789\\.123 \\- \\- \\[[0-9]{2}\\/[a-zA-Z]{3}\\/[0-9]{4}:[0-9]{2}:[0-9]{2}:[0-9]{2} \\-[0-9]{4}\\] \"GET \\/v1\\/test WS\" 200 \\- \"https://moxie.org\" \"SmertZeSmert\"\n")).isTrue();
System.out.println(listAppender.list.get(0));
System.out.println(new String(listAppender.outputStream.toByteArray()));
}
private static class ListAppenderFactory<T extends DeferredProcessingAware> extends AbstractOutputStreamAppenderFactory<T> {
private final ListAppender<T> listAppender;
public ListAppenderFactory(ListAppender<T> listAppender) {
this.listAppender = listAppender;
}
@Override
protected OutputStreamAppender<T> appender(LoggerContext context) {
listAppender.setContext(context);
return listAppender;
}
}
private static class ListAppender<E> extends OutputStreamAppender<E> {
public final List<E> list = new ArrayList<E>();
public final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
protected void append(E e) {
super.append(e);
synchronized (list) {
list.add(e);
list.notifyAll();
}
}
@Override
public void start() {
setOutputStream(outputStream);
super.start();
}
public void waitForListSize(int size) throws InterruptedException {
synchronized (list) {
while (list.size() < size) {
list.wait(5000);
}
}
}
}
}