Update to Dropwizard 5

Co-authored-by: Chris Eager <chris@signal.org>
This commit is contained in:
ravi-signal
2025-11-04 12:18:56 -06:00
committed by GitHub
parent 24f8f48a26
commit 4dbd564442
36 changed files with 703 additions and 664 deletions

View File

@@ -15,7 +15,6 @@ import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -35,7 +34,6 @@ public class BufferingInterceptorIntegrationTest {
environment.jersey().register(testController);
environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10));
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
}
}

View File

@@ -6,7 +6,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -19,39 +18,40 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.EnumSet;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class ProvisioningTimeoutIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
@@ -77,9 +77,9 @@ public class ProvisioningTimeoutIntegrationTest {
CompletableFuture<String> provisioningAddressFuture = new CompletableFuture<>();
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE
&& webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) {
MessageProtos.ProvisioningAddress provisioningAddress =
@@ -92,7 +92,7 @@ public class ProvisioningTimeoutIntegrationTest {
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
super.onWebSocketBinary(payload, offset, length);
super.onWebSocketBinary(payload, callback);
}
}
@@ -106,21 +106,17 @@ public class ProvisioningTimeoutIntegrationTest {
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.setConnectListener(
new ProvisioningConnectListener(mock(ProvisioningManager.class), scheduler, Duration.ofSeconds(5)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet = environment.servlets()
.addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> {
container.addMapping("/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
}
}

View File

@@ -9,8 +9,6 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
@@ -20,19 +18,20 @@ import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.Optional;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@@ -41,6 +40,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class WebsocketResourceProviderIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
@@ -72,9 +72,6 @@ public class WebsocketResourceProviderIntegrationTest {
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class)));
@@ -85,15 +82,13 @@ public class WebsocketResourceProviderIntegrationTest {
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> {
container.addMapping("/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}

View File

@@ -15,8 +15,8 @@ import java.util.List;
import java.util.Optional;
import java.util.UUID;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;

View File

@@ -13,20 +13,17 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import java.io.IOException;
import java.net.InetAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -35,14 +32,15 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.security.auth.Subject;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.util.HostPort;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
@@ -55,6 +53,7 @@ import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFa
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class RemoteAddressFilterIntegrationTest {
private static final String WEBSOCKET_PREFIX = "/websocket";
@@ -131,7 +130,7 @@ class RemoteAddressFilterIntegrationTest {
}
}
private static class ClientEndpoint implements WebSocketListener {
public static class ClientEndpoint implements Session.Listener.AutoDemanding {
private final String requestPath;
private final CompletableFuture<byte[]> responseFuture;
@@ -145,22 +144,19 @@ class RemoteAddressFilterIntegrationTest {
}
@Override
public void onWebSocketConnect(final Session session) {
public void onWebSocketOpen(final Session session) {
final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath,
List.of("Accept: application/json"),
Optional.empty()).toByteArray();
try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP);
}
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
assert 200 == webSocketMessage.getResponseMessage().getStatus();
@@ -206,10 +202,6 @@ class RemoteAddressFilterIntegrationTest {
public void run(final Configuration configuration,
final Environment environment) throws Exception {
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
environment.jersey().register(new TestRemoteAddressController());
// WebSocket set up
@@ -220,15 +212,14 @@ class RemoteAddressFilterIntegrationTest {
webSocketEnvironment.jersey().register(new TestWebSocketController());
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
webSocketEnvironment, TestPrincipal.class,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> {
container.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH, webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
}
}

View File

@@ -27,7 +27,6 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import jakarta.annotation.Priority;
import jakarta.servlet.DispatcherType;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.NotAuthorizedException;
@@ -44,7 +43,6 @@ import java.io.IOException;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
@@ -54,25 +52,28 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.security.auth.Subject;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.handler.EventsHandler;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@@ -80,7 +81,8 @@ import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class MetricsHttpChannelListenerIntegrationTest {
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class MetricsHttpEventHandlerIntegrationTest {
private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP;
private static final MeterRegistry METER_REGISTRY = mock(MeterRegistry.class);
@@ -90,7 +92,7 @@ class MetricsHttpChannelListenerIntegrationTest {
private static final AtomicReference<CountDownLatch> COUNT_DOWN_LATCH_FUTURE_REFERENCE = new AtomicReference<>();
private static final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>(
MetricsHttpChannelListenerIntegrationTest.TestApplication.class);
MetricsHttpEventHandlerIntegrationTest.TestApplication.class);
@AfterEach
void teardown() {
@@ -111,9 +113,9 @@ class MetricsHttpChannelListenerIntegrationTest {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
final Map<String, Counter> counterMap = Map.of(
MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class)));
@@ -147,7 +149,7 @@ class MetricsHttpChannelListenerIntegrationTest {
assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS));
verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(REQUEST_COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue();
@@ -158,11 +160,11 @@ class MetricsHttpChannelListenerIntegrationTest {
}
assertEquals(6, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, expectedTagPath)));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(expectedStatus))));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, expectedTagPath)));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(expectedStatus))));
assertTrue(
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false")));
}
@@ -194,24 +196,19 @@ class MetricsHttpChannelListenerIntegrationTest {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
final Map<String, Counter> counterMap = Map.of(
MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class)));
client.connect(new WebSocketListener() {
@Override
public void onWebSocketConnect(final Session session) {
session.close(1000, "OK");
}
},
client.connect(new AutoClosingWebSocketSessionListener(),
URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest);
assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS));
verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(REQUEST_COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue();
@@ -222,11 +219,11 @@ class MetricsHttpChannelListenerIntegrationTest {
}
assertEquals(6, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101))));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, "/v1/websocket")));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(101))));
assertTrue(
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false")));
}
@@ -248,17 +245,16 @@ class MetricsHttpChannelListenerIntegrationTest {
public void run(final Configuration configuration,
final Environment environment) throws Exception {
final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(
METER_REGISTRY,
mock(ClientReleaseManager.class),
Set.of("/v1/websocket")
);
MetricsHttpEventHandler.configure(environment, METER_REGISTRY, mock(ClientReleaseManager.class), Set.of("/v1/websocket"));
metricsHttpChannelListener.configure(environment);
environment.lifecycle().addEventListener(new TestListener(COUNT_DOWN_LATCH_FUTURE_REFERENCE));
environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.lifecycle().addEventListener(new LifeCycle.Listener() {
@Override
public void lifeCycleStarting(final LifeCycle event) {
if (event instanceof Server server) {
server.setHandler(new TestListener(server.getHandler(), COUNT_DOWN_LATCH_FUTURE_REFERENCE));
}
}
});
environment.jersey().register(new TestResource());
environment.jersey().register(new TestAuthFilter());
@@ -271,14 +267,15 @@ class MetricsHttpChannelListenerIntegrationTest {
webSocketEnvironment.jersey().register(new TestResource());
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
webSocketEnvironment, TestPrincipal.class,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
environment.servlets().addServlet("WebSocket", webSocketServlet)
.addMapping("/v1/websocket");
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(),
(servletContext, container) -> {
container.addMapping("/v1/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
}
}
@@ -294,36 +291,23 @@ class MetricsHttpChannelListenerIntegrationTest {
}
/**
* A simple listener to signal that {@link HttpChannel.Listener} has completed its work, since its onComplete() is on
* A simple listener to signal that {@link EventsHandler} has completed its work, since its onComplete() is on
* a different thread from the one that sends the response, creating a race condition between the listener and the
* test assertions
*/
static class TestListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener {
static class TestListener extends EventsHandler {
private final AtomicReference<CountDownLatch> completableFutureAtomicReference;
TestListener(AtomicReference<CountDownLatch> countDownLatchReference) {
TestListener(final Handler handler, AtomicReference<CountDownLatch> countDownLatchReference) {
super(handler);
this.completableFutureAtomicReference = countDownLatchReference;
}
@Override
public void onComplete(final Request request) {
public void onComplete(Request request, int status, HttpFields headers, Throwable failure) {
completableFutureAtomicReference.get().countDown();
}
@Override
public void beanAdded(final Container parent, final Object child) {
if (child instanceof Connector connector) {
connector.addBean(this);
}
}
@Override
public void beanRemoved(final Container parent, final Object child) {
}
}
@Path("/v1/test")
@@ -365,4 +349,11 @@ class MetricsHttpChannelListenerIntegrationTest {
}
}
public static class AutoClosingWebSocketSessionListener implements Session.Listener.AutoDemanding {
@Override
public void onWebSocketOpen(final Session session) {
session.close(1000, "OK", Callback.NOOP);
}
}
}

View File

@@ -18,13 +18,18 @@ import com.google.common.net.HttpHeaders;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.util.Collections;
import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.glassfish.jersey.uri.UriTemplate;
import org.junit.jupiter.api.BeforeEach;
@@ -34,7 +39,8 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
class MetricsHttpChannelListenerTest {
class MetricsHttpEventHandlerTest {
private final static String USER_AGENT = "Signal-Android/6.53.7 (Android 8.1)";
private MeterRegistry meterRegistry;
private Counter requestCounter;
@@ -42,7 +48,7 @@ class MetricsHttpChannelListenerTest {
private Counter responseBytesCounter;
private Counter requestBytesCounter;
private ClientReleaseManager clientReleaseManager;
private MetricsHttpChannelListener listener;
private MetricsHttpEventHandler listener;
@BeforeEach
void setup() {
@@ -52,26 +58,27 @@ class MetricsHttpChannelListenerTest {
responseBytesCounter = mock(Counter.class);
requestBytesCounter = mock(Counter.class);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestCounter);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestsByVersionCounter);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(responseBytesCounter);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
clientReleaseManager = mock(ClientReleaseManager.class);
listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet());
listener = new MetricsHttpEventHandler(null, meterRegistry, clientReleaseManager, Set.of("/test"));
}
@Test
@ParameterizedTest
@ValueSource(booleans = {true, false})
@SuppressWarnings("unchecked")
void testRequests() {
void testRequests(boolean pathFromFilter) {
final String path = "/test";
final String method = "GET";
final int statusCode = 200;
@@ -81,28 +88,39 @@ class MetricsHttpChannelListenerTest {
final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)");
final HttpFields.Mutable requestHeaders = HttpFields.build();
requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT);
when(request.getHeaders()).thenReturn(requestHeaders);
when(request.getHttpURI()).thenReturn(httpUri);
if (pathFromFilter) {
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT));
} else {
when(request.setAttribute(eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME), any())).thenAnswer(invocation -> {
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(invocation.getArgument(1));
return null;
});
}
final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode);
when(response.getContentCount()).thenReturn(1024L);
when(request.getResponse()).thenReturn(response);
when(request.getContentRead()).thenReturn(512L);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
listener.onComplete(request);
listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true));
listener.onResponseWrite(request, true, ByteBuffer.allocate(1024));
listener.onComplete(request, statusCode, requestHeaders, null);
verify(requestCounter).increment();
verify(responseBytesCounter).increment(1024L);
verify(requestBytesCounter).increment(512L);
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture());
final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagCaptor.getValue()) {
@@ -110,11 +128,11 @@ class MetricsHttpChannelListenerTest {
}
assertEquals(6, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, path)));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, method)));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(statusCode))));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, path)));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, method)));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(statusCode))));
assertTrue(
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())));
tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false")));
}
@@ -133,23 +151,22 @@ class MetricsHttpChannelListenerTest {
final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
final HttpFields.Mutable requestHeaders = HttpFields.build();
requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT);
when(request.getHeaders()).thenReturn(requestHeaders);
when(request.getHttpURI()).thenReturn(httpUri);
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT));
final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode);
when(response.getContentCount()).thenReturn(1024L);
when(request.getResponse()).thenReturn(response);
when(request.getContentRead()).thenReturn(512L);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
listener.onComplete(request);
listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true));
listener.onResponseWrite(request, true, ByteBuffer.allocate(1024));
listener.onComplete(request, statusCode, requestHeaders, null);
if (versionActive) {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME),
verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME),
tagCaptor.capture());
final Set<Tag> tags = new HashSet<>();
tags.clear();
@@ -163,7 +180,38 @@ class MetricsHttpChannelListenerTest {
} else {
verifyNoInteractions(requestsByVersionCounter);
}
}
@Test
void testResponseFilterSetsRequestInfo() {
final ContainerRequest request = mock(ContainerRequest.class);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test")));
when(request.getMethod()).thenReturn("GET");
when(request.getHeaders()).thenReturn(null);
when(request.getUriInfo()).thenReturn(extendedUriInfo);
when(request.getHeaderString(HttpHeaders.USER_AGENT)).thenReturn(USER_AGENT);
new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class));
verify(request).setProperty(
eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME),
eq(new MetricsHttpEventHandler.RequestInfo("/test", "GET", USER_AGENT)));
}
@Test
void testResponseFilterModifiesRequestInfo() {
final MetricsHttpEventHandler.RequestInfo requestInfo =
new MetricsHttpEventHandler.RequestInfo("unknown", "POST", USER_AGENT);
final ContainerRequest request = mock(ContainerRequest.class);
when(request.getProperty(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)).thenReturn(requestInfo);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test")));
when(request.getUriInfo()).thenReturn(extendedUriInfo);
new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class));
assertEquals(new MetricsHttpEventHandler.RequestInfo("/test", "POST", USER_AGENT), requestInfo);
}
}

View File

@@ -34,10 +34,9 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
@@ -160,11 +159,9 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)");
when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of("Signal-Android/4.53.7 (Android 8.1)")));
@@ -176,15 +173,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -228,11 +225,9 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(
@@ -242,15 +237,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -297,11 +292,9 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(
@@ -311,15 +304,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);

View File

@@ -149,15 +149,13 @@ class TlsCertificateExpirationUtilTest {
@Test
void test() throws Exception {
try (Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64)) {
final Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64);
final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD);
final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD);
final Map<String, Instant> expected = Map.of(
"localhost:EdDSA", EDDSA_EXPIRATION,
"localhost:RSA", RSA_EXPIRATION);
assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD));
}
final Map<String, Instant> expected = Map.of(
"localhost:EdDSA", EDDSA_EXPIRATION,
"localhost:RSA", RSA_EXPIRATION);
assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD));
}
}

View File

@@ -4,14 +4,6 @@
*/
package org.whispersystems.textsecuregcm.tests.util;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Objects;
@@ -19,8 +11,14 @@ import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
public class TestWebsocketListener implements WebSocketListener {
public class TestWebsocketListener implements Session.Listener.AutoDemanding {
private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>();
@@ -34,7 +32,7 @@ public class TestWebsocketListener implements WebSocketListener {
@Override
public void onWebSocketConnect(final Session session) {
public void onWebSocketOpen(final Session session) {
started.complete(session);
}
@@ -63,19 +61,15 @@ public class TestWebsocketListener implements WebSocketListener {
responseFutures.put(id, future);
final byte[] requestBytes = messageFactory.createRequest(
Optional.of(id), verb, requestPath, headers, body).toByteArray();
try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP);
return future;
});
}
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
.complete(webSocketMessage.getResponseMessage());

View File

@@ -6,12 +6,9 @@
package org.whispersystems.textsecuregcm.util.jetty;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URI;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.Path;
import java.util.Base64;
import org.eclipse.jetty.util.resource.Resource;
@@ -30,15 +27,16 @@ public class TestResource extends Resource {
}
@Override
public boolean isContainedIn(final Resource r) throws MalformedURLException {
return false;
public Path getPath() {
return null;
}
@Override
public void close() {
public InputStream newInputStream() {
return new ByteArrayInputStream(data);
}
@Override
public boolean exists() {
return true;
@@ -50,13 +48,13 @@ public class TestResource extends Resource {
}
@Override
public long lastModified() {
return 0;
public boolean isReadable() {
return true;
}
@Override
public long length() {
return 0;
return data.length;
}
@Override
@@ -64,43 +62,19 @@ public class TestResource extends Resource {
return null;
}
@Override
public File getFile() throws IOException {
return null;
}
@Override
public String getName() {
return name;
}
@Override
public InputStream getInputStream() throws IOException {
return new ByteArrayInputStream(data);
public String getFileName() {
return "";
}
@Override
public ReadableByteChannel getReadableByteChannel() throws IOException {
public Resource resolve(final String subUriPath) {
return null;
}
@Override
public boolean delete() throws SecurityException {
return false;
}
@Override
public boolean renameTo(final Resource dest) throws SecurityException {
return false;
}
@Override
public String[] list() {
return new String[]{name};
}
@Override
public Resource addPath(final String path) throws IOException, MalformedURLException {
return this;
}
}

View File

@@ -39,10 +39,9 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.ServerProperties;
@@ -143,12 +142,12 @@ class LoggingUnhandledExceptionMapperTest {
WebSocketResourceProvider<TestPrincipal> provider = createWebsocketProvider(userAgentHeader, session,
responseFuture::complete);
provider.onWebSocketConnect(session);
provider.onWebSocketOpen(session);
byte[] message = new ProtobufWebSocketMessageFactory()
.createRequest(Optional.of(111L), "GET", targetPath, new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP);
responseFuture.get(1, TimeUnit.SECONDS);
@@ -179,15 +178,13 @@ class LoggingUnhandledExceptionMapperTest {
TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
doAnswer(answer -> {
responseHandler.accept(answer.getArgument(0, ByteBuffer.class));
return null;
}).when(remoteEndpoint).sendBytes(any(), any(WriteCallback.class));
}).when(session).sendBinary(any(ByteBuffer.class), any(Callback.class));
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn(userAgentHeader);
when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of(userAgentHeader)));

View File

@@ -19,7 +19,7 @@ import java.util.Optional;
import java.util.UUID;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
@@ -44,7 +44,7 @@ class WebSocketAccountAuthenticatorTest {
private AccountAuthenticator accountAuthenticator;
private UpgradeRequest upgradeRequest;
private JettyServerUpgradeRequest upgradeRequest;
@BeforeEach
void setUp() {
@@ -56,7 +56,7 @@ class WebSocketAccountAuthenticatorTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty());
upgradeRequest = mock(UpgradeRequest.class);
upgradeRequest = mock(JettyServerUpgradeRequest.class);
}
@ParameterizedTest