From 4dbd56444240e280fbc979c449aec57f3bcaaedd Mon Sep 17 00:00:00 2001 From: ravi-signal <99042880+ravi-signal@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:18:56 -0600 Subject: [PATCH] Update to Dropwizard 5 Co-authored-by: Chris Eager --- pom.xml | 25 +- service/pom.xml | 8 +- .../textsecuregcm/WhisperServerService.java | 60 ++--- ...ceAuthenticatedWebSocketUpgradeFilter.java | 5 +- .../textsecuregcm/filters/PriorityFilter.java | 82 ++++++ .../filters/RemoteAddressFilter.java | 4 - .../MetricsApplicationEventListener.java | 4 +- .../metrics/MetricsHttpChannelListener.java | 192 -------------- .../metrics/MetricsHttpEventHandler.java | 245 ++++++++++++++++++ .../metrics/MetricsRequestEventListener.java | 4 +- .../metrics/TlsCertificateExpirationUtil.java | 4 +- .../WebSocketAccountAuthenticator.java | 8 +- .../BufferingInterceptorIntegrationTest.java | 2 - .../ProvisioningTimeoutIntegrationTest.java | 32 +-- ...socketResourceProviderIntegrationTest.java | 23 +- ...thenticatedWebSocketUpgradeFilterTest.java | 4 +- .../RemoteAddressFilterIntegrationTest.java | 39 ++- ...tricsHttpEventHandlerIntegrationTest.java} | 119 ++++----- ....java => MetricsHttpEventHandlerTest.java} | 114 +++++--- .../MetricsRequestEventListenerTest.java | 27 +- .../TlsCertificateExpirationUtilTest.java | 14 +- .../tests/util/TestWebsocketListener.java | 28 +- .../util/jetty/TestResource.java | 50 +--- .../LoggingUnhandledExceptionMapperTest.java | 11 +- .../WebSocketAccountAuthenticatorTest.java | 6 +- websocket-resources/pom.xml | 12 +- .../websocket/WebSocketClient.java | 15 +- .../websocket/WebSocketResourceProvider.java | 22 +- .../WebSocketResourceProviderFactory.java | 24 +- .../AuthenticatedWebSocketUpgradeFilter.java | 4 +- .../auth/WebSocketAuthenticator.java | 4 +- .../messages/WebSocketMessageFactory.java | 17 +- .../protobuf/ProtobufWebSocketMessage.java | 5 +- .../ProtobufWebSocketMessageFactory.java | 5 +- .../WebSocketResourceProviderFactoryTest.java | 34 +-- .../WebSocketResourceProviderTest.java | 115 +++----- 36 files changed, 703 insertions(+), 664 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/filters/PriorityFilter.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandler.java rename service/src/test/java/org/whispersystems/textsecuregcm/metrics/{MetricsHttpChannelListenerIntegrationTest.java => MetricsHttpEventHandlerIntegrationTest.java} (73%) rename service/src/test/java/org/whispersystems/textsecuregcm/metrics/{MetricsHttpChannelListenerTest.java => MetricsHttpEventHandlerTest.java} (50%) diff --git a/pom.xml b/pom.xml index b8645eabb..57055f08f 100644 --- a/pom.xml +++ b/pom.xml @@ -41,7 +41,7 @@ 3.44.0 1.14.1 2.20.0 - 4.0.16 + 5.0.0 1.1.14 + + org.ow2.asm + asm-commons + + + org.ow2.asm + asm-tree + + + org.ow2.asm + asm + + org.foundationdb @@ -257,12 +272,6 @@ commons-logging 1.3.5 - - org.ow2.asm - asm - 9.8 - test - com.stripe stripe-java @@ -360,7 +369,7 @@ org.wiremock - wiremock + wiremock-jetty12 3.13.1 test diff --git a/service/pom.xml b/service/pom.xml index edfcebcb2..be8e95bee 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -237,15 +237,15 @@ org.eclipse.jetty.websocket - websocket-jetty-api + jetty-websocket-jetty-api - org.eclipse.jetty - jetty-servlets + org.eclipse.jetty.ee10 + jetty-ee10-servlets org.eclipse.jetty.websocket - websocket-jetty-client + jetty-websocket-jetty-client test diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index e735ded3b..2fe6e4ec1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -37,15 +37,12 @@ import io.netty.resolver.ResolvedAddressTypes; import io.netty.resolver.dns.DnsNameResolver; import io.netty.resolver.dns.DnsNameResolverBuilder; import jakarta.servlet.DispatcherType; -import jakarta.servlet.Filter; -import jakarta.servlet.ServletRegistration; import java.io.ByteArrayInputStream; import java.net.InetSocketAddress; import java.net.http.HttpClient; import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; -import java.util.ArrayList; import java.util.Collections; import java.util.EnumSet; import java.util.List; @@ -60,9 +57,9 @@ import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.function.Function; import java.util.stream.Stream; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.websocket.core.WebSocketComponents; import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.glassfish.jersey.server.ServerProperties; import org.signal.i18n.HeaderControlledResourceBundleLookup; import org.signal.libsignal.zkgroup.GenericServerSecretParams; @@ -136,6 +133,7 @@ import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; import org.whispersystems.textsecuregcm.currency.FixerClient; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.filters.ExternalRequestFilter; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; @@ -184,7 +182,7 @@ import org.whispersystems.textsecuregcm.metrics.BackupMetrics; import org.whispersystems.textsecuregcm.metrics.CallQualitySurveyManager; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener; -import org.whispersystems.textsecuregcm.metrics.MetricsHttpChannelListener; +import org.whispersystems.textsecuregcm.metrics.MetricsHttpEventHandler; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener; @@ -911,17 +909,6 @@ public class WhisperServerService extends Application filters = new ArrayList<>(); - filters.add(remoteDeprecationFilter); - filters.add(new RemoteAddressFilter()); - filters.add(new TimestampResponseFilter()); - - for (Filter filter : filters) { - environment.servlets() - .addFilter(filter.getClass().getSimpleName(), filter) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); - } - if (!config.getExternalRequestFilterConfiguration().paths().isEmpty()) { environment.servlets().addFilter(ExternalRequestFilter.class.getSimpleName(), new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(), @@ -938,9 +925,7 @@ public class WhisperServerService extends Application { - if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) { - WebSocketComponents components = - WebSocketServerComponents.getWebSocketComponents(environment.getApplicationContext().getServletContext()); - components.getExtensionRegistry().unregister("permessage-deflate"); - } - }); - WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(), - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + webSocketEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>( - provisioningEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(), - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + provisioningEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); - ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), + (servletContext, container) -> { + container.addMapping(websocketServletPath, webSocketServlet); + container.addMapping(provisioningWebsocketServletPath, provisioningServlet); - websocket.addMapping(websocketServletPath); - websocket.setAsyncSupported(true); + PriorityFilter.ensureFilter(servletContext, new TimestampResponseFilter()); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + PriorityFilter.ensureFilter(servletContext, remoteDeprecationFilter); - provisioning.addMapping(provisioningWebsocketServletPath); - provisioning.setAsyncSupported(true); + container.setMaxBinaryMessageSize(config.getWebSocketConfiguration().getMaxBinaryMessageSize()); + container.setMaxTextMessageSize(config.getWebSocketConfiguration().getMaxTextMessageSize()); + + + if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) { + WebSocketComponents components = + WebSocketServerComponents.getWebSocketComponents(environment.getApplicationContext()); + components.getExtensionRegistry().unregister("permessage-deflate"); + } + }); environment.admin().addTask(new SetRequestLoggingEnabledTask()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java index 38ff7d8ef..e2853497c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java @@ -10,10 +10,9 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.Optional; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/PriorityFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/PriorityFilter.java new file mode 100644 index 000000000..03070cf00 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/PriorityFilter.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import jakarta.servlet.DispatcherType; +import jakarta.servlet.Filter; +import jakarta.servlet.ServletContext; +import java.util.EnumSet; +import java.util.Objects; +import org.eclipse.jetty.ee10.servlet.FilterHolder; +import org.eclipse.jetty.ee10.servlet.FilterMapping; +import org.eclipse.jetty.ee10.servlet.ServletContextHandler; +import org.eclipse.jetty.ee10.servlet.ServletHandler; +import org.eclipse.jetty.server.handler.ContextHandler; +import org.eclipse.jetty.util.component.LifeCycle; + +public class PriorityFilter { + + private PriorityFilter() {} + + private static FilterHolder getFilter(ServletContext servletContext, final Class filterClass) { + final ContextHandler contextHandler = Objects.requireNonNull(ServletContextHandler.getServletContextHandler(servletContext)); + final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class); + return servletHandler.getFilter(filterClass.getName()); + } + + /** + * Ensure a filter is available on the provided ServletContext, a new filter will added if one does not already + * exist. + *

+ * If a new filter is added, it will be added before all other filters. + *

+ * Modeled after {@link org.eclipse.jetty.ee10.websocket.servlet.WebSocketUpgradeFilter#ensureFilter(ServletContext)}, + * since its use of {@link org.eclipse.jetty.ee10.servlet.ServletHandler#prependFilter(FilterHolder)} is what makes + * this necessary. + */ + public static void ensureFilter(final ServletContext servletContext, final Filter filter) { + FilterHolder existingFilter = getFilter(servletContext, filter.getClass()); + if (existingFilter != null) { + return; + } + + final ContextHandler contextHandler = ServletContextHandler.getServletContextHandler(servletContext); + final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class); + + final String pathSpec = "/*"; + final FilterHolder holder = new FilterHolder(filter); + holder.setName(filter.getClass().getName()); + holder.setAsyncSupported(true); + + final FilterMapping mapping = new FilterMapping(); + mapping.setFilterName(holder.getName()); + mapping.setPathSpec(pathSpec); + mapping.setDispatcherTypes(EnumSet.of(DispatcherType.REQUEST)); + + // Add as the first filter in the list. + servletHandler.prependFilter(holder); + servletHandler.prependFilterMapping(mapping); + + // If we create the filter we must also make sure it is removed if the context is stopped. + contextHandler.addEventListener(new LifeCycle.Listener() + { + @Override + public void lifeCycleStopping(LifeCycle event) + { + servletHandler.removeFilterHolder(holder); + servletHandler.removeFilterMapping(mapping); + contextHandler.removeEventListener(this); + } + + @Override + public String toString() + { + return String.format("%sCleanupListener", filter.getClass().getSimpleName()); + } + }); + + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java index c1ea020b6..3b8090e66 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java @@ -26,10 +26,6 @@ public class RemoteAddressFilter implements Filter { public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress"; private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class); - - public RemoteAddressFilter() { - } - @Override public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws ServletException, IOException { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java index 5c58de79a..2d93ec012 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsApplicationEventListener.java @@ -14,7 +14,7 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; /** * Delegates request events to a listener that captures and reports request-level metrics. * - * @see MetricsHttpChannelListener + * @see MetricsHttpEventHandler * @see MetricsRequestEventListener */ public class MetricsApplicationEventListener implements ApplicationEventListener { @@ -23,7 +23,7 @@ public class MetricsApplicationEventListener implements ApplicationEventListener public MetricsApplicationEventListener(final TrafficSource trafficSource, final ClientReleaseManager clientReleaseManager) { if (trafficSource == TrafficSource.HTTP) { - throw new IllegalArgumentException("Use " + MetricsHttpChannelListener.class.getName() + " for HTTP traffic"); + throw new IllegalArgumentException("Use " + MetricsHttpEventHandler.class.getName() + " for HTTP traffic"); } this.metricsRequestEventListener = new MetricsRequestEventListener(trafficSource, clientReleaseManager); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java deleted file mode 100644 index 479870b38..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.metrics; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.net.HttpHeaders; -import io.dropwizard.core.setup.Environment; -import io.micrometer.core.instrument.MeterRegistry; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Tag; -import io.micrometer.core.instrument.Tags; -import jakarta.ws.rs.container.ContainerRequestContext; -import jakarta.ws.rs.container.ContainerResponseContext; -import jakarta.ws.rs.container.ContainerResponseFilter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import javax.annotation.Nullable; -import org.eclipse.jetty.server.Connector; -import org.eclipse.jetty.server.HttpChannel; -import org.eclipse.jetty.server.Request; -import org.eclipse.jetty.util.component.Container; -import org.eclipse.jetty.util.component.LifeCycle; -import org.glassfish.jersey.server.ExtendedUriInfo; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; -import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil; - -/** - * Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get - * templated Jersey request paths, it implements {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give - * itself access to the template. It is limited to {@link TrafficSource#HTTP} requests. - *

- * It implements {@link LifeCycle.Listener} without overriding methods, so that it can be an event listener that - * Dropwizard will attach to the container—the {@link Container.Listener} implementation is where it attaches - * itself to any {@link Connector}s. - * - * @see MetricsRequestEventListener - */ -public class MetricsHttpChannelListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener, - ContainerResponseFilter { - - private static final Logger logger = LoggerFactory.getLogger(MetricsHttpChannelListener.class); - - private record RequestInfo(String path, String method, int statusCode, @Nullable String userAgent) { - - } - - private final ClientReleaseManager clientReleaseManager; - private final Set servletPaths; - - // Use the same counter namespace as MetricsRequestEventListener for continuity - public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME; - public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME; - @VisibleForTesting - static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME; - @VisibleForTesting - static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME; - @VisibleForTesting - static final String URI_INFO_PROPERTY_NAME = MetricsHttpChannelListener.class.getName() + ".uriInfo"; - - @VisibleForTesting - static final String PATH_TAG = "path"; - - @VisibleForTesting - static final String METHOD_TAG = "method"; - - @VisibleForTesting - static final String STATUS_CODE_TAG = "status"; - - @VisibleForTesting - static final String TRAFFIC_SOURCE_TAG = "trafficSource"; - - private final MeterRegistry meterRegistry; - - - public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager, final Set servletPaths) { - this(Metrics.globalRegistry, clientReleaseManager, servletPaths); - } - - @VisibleForTesting - MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager, - final Set servletPaths) { - this.meterRegistry = meterRegistry; - this.clientReleaseManager = clientReleaseManager; - this.servletPaths = servletPaths; - } - - public void configure(final Environment environment) { - // register as ContainerResponseFilter - environment.jersey().register(this); - - // hook into lifecycle events, to react to the Connector being added - environment.lifecycle().addEventListener(this); - } - - @Override - public void onRequestFailure(final Request request, final Throwable failure) { - - if (logger.isDebugEnabled()) { - final RequestInfo requestInfo = getRequestInfo(request); - - logger.debug("Request failure: {} {} ({}) [{}] ", - requestInfo.method(), - requestInfo.path(), - requestInfo.userAgent(), - requestInfo.statusCode(), failure); - } - } - - @Override - public void onResponseFailure(Request request, Throwable failure) { - - if (failure instanceof org.eclipse.jetty.io.EofException) { - // the client disconnected early - return; - } - - final RequestInfo requestInfo = getRequestInfo(request); - - logger.warn("Response failure: {} {} ({}) [{}] ", - requestInfo.method(), - requestInfo.path(), - requestInfo.userAgent(), - requestInfo.statusCode(), failure); - } - - @Override - public void onComplete(final Request request) { - - final RequestInfo requestInfo = getRequestInfo(request); - - final List tags = new ArrayList<>(5); - tags.add(Tag.of(PATH_TAG, requestInfo.path())); - tags.add(Tag.of(METHOD_TAG, requestInfo.method())); - tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(requestInfo.statusCode()))); - tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())); - tags.addAll(UserAgentTagUtil.getLibsignalAndPlatformTags(requestInfo.userAgent())); - - meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment(); - - meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(request.getResponse().getContentCount()); - meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(request.getContentRead()); - - UserAgentTagUtil.getClientVersionTag(requestInfo.userAgent(), clientReleaseManager).ifPresent( - clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME, - Tags.of(clientVersionTag, UserAgentTagUtil.getPlatformTag(requestInfo.userAgent()))).increment()); - } - - @Override - public void beanAdded(final Container parent, final Object child) { - if (child instanceof Connector connector) { - connector.addBean(this); - } - } - - @Override - public void beanRemoved(final Container parent, final Object child) { - - } - - @Override - public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext) - throws IOException { - requestContext.setProperty(URI_INFO_PROPERTY_NAME, requestContext.getUriInfo()); - } - - private RequestInfo getRequestInfo(Request request) { - final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME)) - .map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr)) - .orElseGet(() -> - Optional.ofNullable(request.getPathInfo()) - .filter(servletPaths::contains) - .orElse("unknown") - ); - final String method = Optional.ofNullable(request.getMethod()).orElse("unknown"); - // Response cannot be null, but its status might not always reflect an actual response status, since it gets - // initialized to 200 - final int status = request.getResponse().getStatus(); - - @Nullable final String userAgent = request.getHeader(HttpHeaders.USER_AGENT); - - return new RequestInfo(path, method, status, userAgent); - } - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandler.java new file mode 100644 index 000000000..9ca1046fd --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandler.java @@ -0,0 +1,245 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.metrics; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HttpHeaders; +import io.dropwizard.core.setup.Environment; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import jakarta.validation.constraints.NotNull; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerResponseContext; +import jakarta.ws.rs.container.ContainerResponseFilter; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import javax.annotation.Nullable; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.io.Content; +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.handler.EventsHandler; +import org.eclipse.jetty.util.component.LifeCycle; +import org.glassfish.jersey.server.ExtendedUriInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil; + +/** + * Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get + * templated Jersey request paths, it adds a {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give + * itself access to the template. It is limited to {@link TrafficSource#HTTP} requests. + * + * @see MetricsRequestEventListener + */ +public class MetricsHttpEventHandler extends EventsHandler { + + private static final Logger logger = LoggerFactory.getLogger(MetricsHttpEventHandler.class); + + + private final ClientReleaseManager clientReleaseManager; + private final Set servletPaths; + + // Use the same counter namespace as MetricsRequestEventListener for continuity + public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME; + public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME; + @VisibleForTesting + static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME; + @VisibleForTesting + static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME; + @VisibleForTesting + static final String REQUEST_INFO_PROPERTY_NAME = MetricsHttpEventHandler.class.getName() + ".requestInfo"; + + @VisibleForTesting + static final String PATH_TAG = "path"; + + @VisibleForTesting + static final String METHOD_TAG = "method"; + + @VisibleForTesting + static final String STATUS_CODE_TAG = "status"; + + @VisibleForTesting + static final String TRAFFIC_SOURCE_TAG = "trafficSource"; + + private final MeterRegistry meterRegistry; + + @VisibleForTesting + MetricsHttpEventHandler( + final Handler handler, + final MeterRegistry meterRegistry, + final ClientReleaseManager clientReleaseManager, + final Set servletPaths) { + super(handler); + + this.meterRegistry = meterRegistry; + this.clientReleaseManager = clientReleaseManager; + this.servletPaths = servletPaths; + } + + /** + * Configure a {@link MetricsHttpEventHandler} + * + * @param environment A dropwizard {@link org.eclipse.jetty.util.component.Environment} + * @param meterRegistry The meter registry to register metrics with + * @param clientReleaseManager A {@link ClientReleaseManager} that determines what tags to include with metrics + * @param servletPaths An allow-list of paths to include in metric tags for requests that are handled by above + * Jersey + */ + public static void configure(final Environment environment, final MeterRegistry meterRegistry, + final ClientReleaseManager clientReleaseManager, final Set servletPaths) { + // register a filter that will set the initial request info + environment.jersey().register(new SetInfoRequestFilter()); + + // hook into lifecycle events, to react to the Connector being added + environment.lifecycle().addEventListener(new LifeCycle.Listener() { + @Override + public void lifeCycleStarting(LifeCycle event) { + if (event instanceof Server server) { + server.setHandler( + new MetricsHttpEventHandler(server.getHandler(), meterRegistry, clientReleaseManager, servletPaths)); + } + } + }); + } + + private void onResponseFailure(Request request, int status, Throwable failure) { + + if (failure instanceof org.eclipse.jetty.io.EofException) { + // the client disconnected early + return; + } + + final RequestInfo requestInfo = getRequestInfo(request); + + logger.warn("Response failure: {} {} ({}) [{}] ", + requestInfo.method, + requestInfo.path, + requestInfo.userAgent, + status, + failure); + } + + @Override + public void onComplete(Request request, int status, HttpFields headers, Throwable failure) { + + super.onComplete(request, status, headers, failure); + + if (failure != null) { + onResponseFailure(request, status, failure); + } + + final RequestInfo requestInfo = getRequestInfo(request); + + final List tags = new ArrayList<>(5); + tags.add(Tag.of(PATH_TAG, requestInfo.path)); + tags.add(Tag.of(METHOD_TAG, requestInfo.method)); + tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(status))); + tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())); + tags.addAll(UserAgentTagUtil.getLibsignalAndPlatformTags(requestInfo.userAgent)); + + meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment(); + meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(requestInfo.responseBytes); + meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(requestInfo.requestBytes); + + UserAgentTagUtil.getClientVersionTag(requestInfo.userAgent, clientReleaseManager).ifPresent( + clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME, + Tags.of(clientVersionTag, UserAgentTagUtil.getPlatformTag(requestInfo.userAgent))).increment()); + } + + @Override + protected void onRequestRead(final Request request, final Content.Chunk chunk) { + super.onRequestRead(request, chunk); + if (chunk != null) { + getRequestInfo(request).requestBytes += chunk.remaining(); + } + } + + @Override + protected void onResponseWrite(final Request request, final boolean last, final ByteBuffer content) { + super.onResponseWrite(request, last, content); + if (content != null) { + getRequestInfo(request).responseBytes += content.remaining(); + } + } + + private RequestInfo getRequestInfo(Request request) { + Object obj = request.getAttribute(REQUEST_INFO_PROPERTY_NAME); + if (obj != null && obj instanceof RequestInfo requestInfo) { + return requestInfo; + } + + // Our ContainerResponseFilter has not run yet. It should eventually run, and will override the path we set here. + // It may not run if this is a websocket upgrade request, a request handled by jetty directly, or a higher priority + // filter aborted the request by throwing an exception, in which case we'll use this path. To avoid giving every + // incorrect path a unique tag we check against a configured list of paths that we know would skip the filter. + final RequestInfo newInfo = new RequestInfo( + Optional.ofNullable(request.getHttpURI().getPath()).filter(servletPaths::contains).orElse("unknown"), + Optional.ofNullable(request.getMethod()).orElse("unknown"), + request.getHeaders().get(HttpHeaders.USER_AGENT)); + + request.setAttribute(REQUEST_INFO_PROPERTY_NAME, newInfo); + return newInfo; + } + + @VisibleForTesting + static class RequestInfo { + + private String path; + private final String method; + private final @Nullable String userAgent; + + private long requestBytes; + private long responseBytes; + + RequestInfo(@NotNull String path, @NotNull String method, @Nullable String userAgent) { + this.path = path; + this.method = method; + this.userAgent = userAgent; + this.requestBytes = 0; + this.responseBytes = 0; + } + + @Override + public boolean equals(final Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + RequestInfo that = (RequestInfo) o; + return requestBytes == that.requestBytes && responseBytes == that.responseBytes && Objects.equals(path, that.path) + && Objects.equals(method, that.method) && Objects.equals(userAgent, that.userAgent); + } + + @Override + public int hashCode() { + return Objects.hash(path, method, userAgent, requestBytes, responseBytes); + } + } + + @VisibleForTesting + static class SetInfoRequestFilter implements ContainerResponseFilter { + + @Override + public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext) { + // Construct the templated URI path. If no matching path is found, this will be "" + final String path = UriInfoUtil.getPathTemplate((ExtendedUriInfo) requestContext.getUriInfo()); + final Object obj = requestContext.getProperty(REQUEST_INFO_PROPERTY_NAME); + if (obj != null && obj instanceof RequestInfo requestInfo) { + requestInfo.path = path; + } else { + requestContext.setProperty(REQUEST_INFO_PROPERTY_NAME, + new RequestInfo(path, requestContext.getMethod(), requestContext.getHeaderString(HttpHeaders.USER_AGENT))); + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java index 1cd3b6d45..2d3c9b6cb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java @@ -27,7 +27,7 @@ import org.whispersystems.websocket.WebSocketResourceProvider; /** * Gathers and reports request-level metrics for WebSocket traffic only. - * For HTTP traffic, use {@link MetricsHttpChannelListener}. + * For HTTP traffic, use {@link MetricsHttpEventHandler}. */ public class MetricsRequestEventListener implements RequestEventListener { @@ -62,7 +62,7 @@ public class MetricsRequestEventListener implements RequestEventListener { this(trafficSource, Metrics.globalRegistry, clientReleaseManager); if (trafficSource == TrafficSource.HTTP) { - logger.warn("Use {} for HTTP traffic", MetricsHttpChannelListener.class.getName()); + logger.warn("Use {} for HTTP traffic", MetricsHttpEventHandler.class.getName()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java index 06f786649..04b29941f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtil.java @@ -21,7 +21,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Optional; -import org.eclipse.jetty.util.resource.Resource; +import org.eclipse.jetty.util.resource.PathResourceFactory; import org.eclipse.jetty.util.security.CertificateUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,7 +37,7 @@ public class TlsCertificateExpirationUtil { final KeyStore keyStore; try { - keyStore = CertificateUtils.getKeyStore(Resource.newResource(keyStorePath), keyStoreType, keyStoreProvider, + keyStore = CertificateUtils.getKeyStore(new PathResourceFactory().newResource(keyStorePath), keyStoreType, keyStoreProvider, keyStorePassword); } catch (Exception e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java index c88fa5e29..4ba267dfc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -8,14 +8,14 @@ package org.whispersystems.textsecuregcm.websocket; import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentialsFromAuthHeader; import com.google.common.net.HttpHeaders; -import javax.annotation.Nullable; import io.dropwizard.auth.basic.BasicCredentials; -import org.eclipse.jetty.websocket.api.UpgradeRequest; +import java.util.Optional; +import javax.annotation.Nullable; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; -import java.util.Optional; public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { @@ -27,7 +27,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator authenticate(final UpgradeRequest request) + public Optional authenticate(final JettyServerUpgradeRequest request) throws InvalidCredentialsException { @Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java index 4183026b3..cce4dedbd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java @@ -15,7 +15,6 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import org.apache.commons.lang3.RandomStringUtils; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.glassfish.jersey.server.ManagedAsync; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -35,7 +34,6 @@ public class BufferingInterceptorIntegrationTest { environment.jersey().register(testController); environment.jersey().register(new BufferingInterceptor()); environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10)); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java index da79f811f..d5204a1c1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java @@ -6,7 +6,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -19,39 +18,40 @@ import io.dropwizard.core.Configuration; import io.dropwizard.core.setup.Environment; import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; -import jakarta.servlet.ServletRegistration; import java.io.IOException; import java.net.URI; +import java.nio.ByteBuffer; import java.time.Duration; -import java.util.EnumSet; import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener; import org.whispersystems.websocket.WebSocketResourceProviderFactory; -import org.whispersystems.websocket.WebsocketHeaders; import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public class ProvisioningTimeoutIntegrationTest { private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = @@ -77,9 +77,9 @@ public class ProvisioningTimeoutIntegrationTest { CompletableFuture provisioningAddressFuture = new CompletableFuture<>(); @Override - public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE && webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) { MessageProtos.ProvisioningAddress provisioningAddress = @@ -92,7 +92,7 @@ public class ProvisioningTimeoutIntegrationTest { } catch (InvalidProtocolBufferException e) { throw new RuntimeException(e); } - super.onWebSocketBinary(payload, offset, length); + super.onWebSocketBinary(payload, callback); } } @@ -106,21 +106,17 @@ public class ProvisioningTimeoutIntegrationTest { final WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment, webSocketConfiguration); - environment.servlets() - .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.setConnectListener( new ProvisioningConnectListener(mock(ProvisioningManager.class), scheduler, Duration.ofSeconds(5))); final WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, - webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); + REMOTE_ADDRESS_ATTRIBUTE_NAME); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - final ServletRegistration.Dynamic websocketServlet = environment.servlets() - .addServlet("WebSocket", webSocketServlet); - websocketServlet.addMapping("/websocket"); - websocketServlet.setAsyncSupported(true); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { + container.addMapping("/websocket", webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java index d74e599b0..606ca0739 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java @@ -9,8 +9,6 @@ import io.dropwizard.core.Configuration; import io.dropwizard.core.setup.Environment; import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; -import jakarta.servlet.ServletRegistration; import jakarta.ws.rs.GET; import jakarta.ws.rs.PUT; import jakarta.ws.rs.Path; @@ -20,19 +18,20 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import java.io.IOException; import java.net.URI; -import java.util.EnumSet; import java.util.Optional; import org.apache.commons.lang3.RandomStringUtils; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ServerProperties; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.websocket.WebSocketResourceProviderFactory; @@ -41,6 +40,7 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public class WebsocketResourceProviderIntegrationTest { private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(TestApplication.class); @@ -72,9 +72,6 @@ public class WebsocketResourceProviderIntegrationTest { new WebSocketEnvironment<>(environment, webSocketConfiguration); environment.jersey().register(testController); - environment.servlets() - .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.jersey().register(testController); webSocketEnvironment.jersey().register(new RemoteAddressFilter()); webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class))); @@ -85,15 +82,13 @@ public class WebsocketResourceProviderIntegrationTest { final WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, - webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); + REMOTE_ADDRESS_ATTRIBUTE_NAME); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { + container.addMapping("/websocket", webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); - final ServletRegistration.Dynamic websocketServlet = - environment.servlets().addServlet("WebSocket", webSocketServlet); - - websocketServlet.addMapping("/websocket"); - websocketServlet.setAsyncSupported(true); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java index 28ea6d113..e7417ebc1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java @@ -15,8 +15,8 @@ import java.util.List; import java.util.Optional; import java.util.UUID; import javax.annotation.Nullable; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java index 96ae5283f..edc6c576e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java @@ -13,20 +13,17 @@ import io.dropwizard.core.Configuration; import io.dropwizard.core.setup.Environment; import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; import jakarta.ws.rs.GET; import jakarta.ws.rs.Path; import jakarta.ws.rs.client.Client; import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.core.Context; -import java.io.IOException; import java.net.InetAddress; import java.net.URI; import java.nio.ByteBuffer; import java.security.Principal; import java.time.Duration; import java.util.Arrays; -import java.util.EnumSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -35,14 +32,15 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.security.auth.Subject; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.util.HostPort; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -55,6 +53,7 @@ import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFa import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class RemoteAddressFilterIntegrationTest { private static final String WEBSOCKET_PREFIX = "/websocket"; @@ -131,7 +130,7 @@ class RemoteAddressFilterIntegrationTest { } } - private static class ClientEndpoint implements WebSocketListener { + public static class ClientEndpoint implements Session.Listener.AutoDemanding { private final String requestPath; private final CompletableFuture responseFuture; @@ -145,22 +144,19 @@ class RemoteAddressFilterIntegrationTest { } @Override - public void onWebSocketConnect(final Session session) { + public void onWebSocketOpen(final Session session) { final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath, List.of("Accept: application/json"), Optional.empty()).toByteArray(); - try { - session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes)); - } catch (IOException e) { - throw new RuntimeException(e); - } + + session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP); } @Override - public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { assert 200 == webSocketMessage.getResponseMessage().getStatus(); @@ -206,10 +202,6 @@ class RemoteAddressFilterIntegrationTest { public void run(final Configuration configuration, final Environment environment) throws Exception { - environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH, - WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH); - environment.jersey().register(new TestRemoteAddressController()); // WebSocket set up @@ -220,15 +212,14 @@ class RemoteAddressFilterIntegrationTest { webSocketEnvironment.jersey().register(new TestWebSocketController()); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, + webSocketEnvironment, TestPrincipal.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet) - .addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH); - + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { + container.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH, webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerIntegrationTest.java similarity index 73% rename from service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerIntegrationTest.java index 40b7d040d..df831156b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerIntegrationTest.java @@ -27,7 +27,6 @@ import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import jakarta.annotation.Priority; -import jakarta.servlet.DispatcherType; import jakarta.ws.rs.GET; import jakarta.ws.rs.InternalServerErrorException; import jakarta.ws.rs.NotAuthorizedException; @@ -44,7 +43,6 @@ import java.io.IOException; import java.net.URI; import java.security.Principal; import java.time.Duration; -import java.util.EnumSet; import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -54,25 +52,28 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.stream.Stream; import javax.security.auth.Subject; -import org.eclipse.jetty.server.Connector; -import org.eclipse.jetty.server.HttpChannel; +import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.Request; -import org.eclipse.jetty.util.component.Container; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.handler.EventsHandler; import org.eclipse.jetty.util.component.LifeCycle; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WebSocketListener; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.filters.PriorityFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.websocket.WebSocketResourceProviderFactory; @@ -80,7 +81,8 @@ import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.setup.WebSocketEnvironment; @ExtendWith(DropwizardExtensionsSupport.class) -class MetricsHttpChannelListenerIntegrationTest { +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +class MetricsHttpEventHandlerIntegrationTest { private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP; private static final MeterRegistry METER_REGISTRY = mock(MeterRegistry.class); @@ -90,7 +92,7 @@ class MetricsHttpChannelListenerIntegrationTest { private static final AtomicReference COUNT_DOWN_LATCH_FUTURE_REFERENCE = new AtomicReference<>(); private static final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>( - MetricsHttpChannelListenerIntegrationTest.TestApplication.class); + MetricsHttpEventHandlerIntegrationTest.TestApplication.class); @AfterEach void teardown() { @@ -111,9 +113,9 @@ class MetricsHttpChannelListenerIntegrationTest { final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final Map counterMap = Map.of( - MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER, - MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, - MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER + MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER, + MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, + MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER ); when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) .thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class))); @@ -147,7 +149,7 @@ class MetricsHttpChannelListenerIntegrationTest { assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS)); - verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(REQUEST_COUNTER).increment(); final Iterable tagIterable = tagCaptor.getValue(); @@ -158,11 +160,11 @@ class MetricsHttpChannelListenerIntegrationTest { } assertEquals(6, tags.size()); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, expectedTagPath))); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"))); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(expectedStatus)))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, expectedTagPath))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(expectedStatus)))); assertTrue( - tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); } @@ -194,24 +196,19 @@ class MetricsHttpChannelListenerIntegrationTest { final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final Map counterMap = Map.of( - MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER, - MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, - MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER + MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER, + MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, + MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER ); when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) .thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class))); - client.connect(new WebSocketListener() { - @Override - public void onWebSocketConnect(final Session session) { - session.close(1000, "OK"); - } - }, + client.connect(new AutoClosingWebSocketSessionListener(), URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest); assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS)); - verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(REQUEST_COUNTER).increment(); final Iterable tagIterable = tagCaptor.getValue(); @@ -222,11 +219,11 @@ class MetricsHttpChannelListenerIntegrationTest { } assertEquals(6, tags.size()); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket"))); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"))); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101)))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, "/v1/websocket"))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(101)))); assertTrue( - tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); } @@ -248,17 +245,16 @@ class MetricsHttpChannelListenerIntegrationTest { public void run(final Configuration configuration, final Environment environment) throws Exception { - final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener( - METER_REGISTRY, - mock(ClientReleaseManager.class), - Set.of("/v1/websocket") - ); + MetricsHttpEventHandler.configure(environment, METER_REGISTRY, mock(ClientReleaseManager.class), Set.of("/v1/websocket")); - metricsHttpChannelListener.configure(environment); - environment.lifecycle().addEventListener(new TestListener(COUNT_DOWN_LATCH_FUTURE_REFERENCE)); - - environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + environment.lifecycle().addEventListener(new LifeCycle.Listener() { + @Override + public void lifeCycleStarting(final LifeCycle event) { + if (event instanceof Server server) { + server.setHandler(new TestListener(server.getHandler(), COUNT_DOWN_LATCH_FUTURE_REFERENCE)); + } + } + }); environment.jersey().register(new TestResource()); environment.jersey().register(new TestAuthFilter()); @@ -271,14 +267,15 @@ class MetricsHttpChannelListenerIntegrationTest { webSocketEnvironment.jersey().register(new TestResource()); - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, + webSocketEnvironment, TestPrincipal.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - environment.servlets().addServlet("WebSocket", webSocketServlet) - .addMapping("/v1/websocket"); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), + (servletContext, container) -> { + container.addMapping("/v1/websocket", webSocketServlet); + PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); + }); } } @@ -294,36 +291,23 @@ class MetricsHttpChannelListenerIntegrationTest { } /** - * A simple listener to signal that {@link HttpChannel.Listener} has completed its work, since its onComplete() is on + * A simple listener to signal that {@link EventsHandler} has completed its work, since its onComplete() is on * a different thread from the one that sends the response, creating a race condition between the listener and the * test assertions */ - static class TestListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener { + static class TestListener extends EventsHandler { private final AtomicReference completableFutureAtomicReference; - TestListener(AtomicReference countDownLatchReference) { - + TestListener(final Handler handler, AtomicReference countDownLatchReference) { + super(handler); this.completableFutureAtomicReference = countDownLatchReference; } @Override - public void onComplete(final Request request) { + public void onComplete(Request request, int status, HttpFields headers, Throwable failure) { completableFutureAtomicReference.get().countDown(); } - - @Override - public void beanAdded(final Container parent, final Object child) { - if (child instanceof Connector connector) { - connector.addBean(this); - } - } - - @Override - public void beanRemoved(final Container parent, final Object child) { - - } - } @Path("/v1/test") @@ -365,4 +349,11 @@ class MetricsHttpChannelListenerIntegrationTest { } } + public static class AutoClosingWebSocketSessionListener implements Session.Listener.AutoDemanding { + @Override + public void onWebSocketOpen(final Session session) { + session.close(1000, "OK", Callback.NOOP); + } + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerTest.java similarity index 50% rename from service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerTest.java index 7405bbd04..b2a0479aa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpEventHandlerTest.java @@ -18,13 +18,18 @@ import com.google.common.net.HttpHeaders; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; -import java.util.Collections; +import java.nio.ByteBuffer; import java.util.HashSet; import java.util.List; import java.util.Set; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpHeader; import org.eclipse.jetty.http.HttpURI; +import org.eclipse.jetty.io.Content; import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Response; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ContainerResponse; import org.glassfish.jersey.server.ExtendedUriInfo; import org.glassfish.jersey.uri.UriTemplate; import org.junit.jupiter.api.BeforeEach; @@ -34,7 +39,8 @@ import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; -class MetricsHttpChannelListenerTest { +class MetricsHttpEventHandlerTest { + private final static String USER_AGENT = "Signal-Android/6.53.7 (Android 8.1)"; private MeterRegistry meterRegistry; private Counter requestCounter; @@ -42,7 +48,7 @@ class MetricsHttpChannelListenerTest { private Counter responseBytesCounter; private Counter requestBytesCounter; private ClientReleaseManager clientReleaseManager; - private MetricsHttpChannelListener listener; + private MetricsHttpEventHandler listener; @BeforeEach void setup() { @@ -52,26 +58,27 @@ class MetricsHttpChannelListenerTest { responseBytesCounter = mock(Counter.class); requestBytesCounter = mock(Counter.class); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestCounter); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestsByVersionCounter); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(responseBytesCounter); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) + when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); clientReleaseManager = mock(ClientReleaseManager.class); - listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet()); + listener = new MetricsHttpEventHandler(null, meterRegistry, clientReleaseManager, Set.of("/test")); } - @Test + @ParameterizedTest + @ValueSource(booleans = {true, false}) @SuppressWarnings("unchecked") - void testRequests() { + void testRequests(boolean pathFromFilter) { final String path = "/test"; final String method = "GET"; final int statusCode = 200; @@ -81,28 +88,39 @@ class MetricsHttpChannelListenerTest { final Request request = mock(Request.class); when(request.getMethod()).thenReturn(method); - when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)"); + + final HttpFields.Mutable requestHeaders = HttpFields.build(); + requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT); + when(request.getHeaders()).thenReturn(requestHeaders); when(request.getHttpURI()).thenReturn(httpUri); + if (pathFromFilter) { + when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)) + .thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT)); + } else { + when(request.setAttribute(eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME), any())).thenAnswer(invocation -> { + when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)) + .thenReturn(invocation.getArgument(1)); + return null; + }); + } + + final Response response = mock(Response.class); when(response.getStatus()).thenReturn(statusCode); - when(response.getContentCount()).thenReturn(1024L); - when(request.getResponse()).thenReturn(response); - when(request.getContentRead()).thenReturn(512L); - final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); - when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo); - when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); - listener.onComplete(request); + listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true)); + listener.onResponseWrite(request, true, ByteBuffer.allocate(1024)); + listener.onComplete(request, statusCode, requestHeaders, null); verify(requestCounter).increment(); verify(responseBytesCounter).increment(1024L); verify(requestBytesCounter).increment(512L); - verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); final Set tags = new HashSet<>(); for (final Tag tag : tagCaptor.getValue()) { @@ -110,11 +128,11 @@ class MetricsHttpChannelListenerTest { } assertEquals(6, tags.size()); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, path))); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, method))); - assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(statusCode)))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, path))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, method))); + assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(statusCode)))); assertTrue( - tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()))); + tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); } @@ -133,23 +151,22 @@ class MetricsHttpChannelListenerTest { final Request request = mock(Request.class); when(request.getMethod()).thenReturn(method); - when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)"); + final HttpFields.Mutable requestHeaders = HttpFields.build(); + requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT); + when(request.getHeaders()).thenReturn(requestHeaders); when(request.getHttpURI()).thenReturn(httpUri); + when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)) + .thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT)); final Response response = mock(Response.class); when(response.getStatus()).thenReturn(statusCode); - when(response.getContentCount()).thenReturn(1024L); - when(request.getResponse()).thenReturn(response); - when(request.getContentRead()).thenReturn(512L); - final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); - when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo); - when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); - - listener.onComplete(request); + listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true)); + listener.onResponseWrite(request, true, ByteBuffer.allocate(1024)); + listener.onComplete(request, statusCode, requestHeaders, null); if (versionActive) { final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); - verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), + verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), tagCaptor.capture()); final Set tags = new HashSet<>(); tags.clear(); @@ -163,7 +180,38 @@ class MetricsHttpChannelListenerTest { } else { verifyNoInteractions(requestsByVersionCounter); } + } + @Test + void testResponseFilterSetsRequestInfo() { + final ContainerRequest request = mock(ContainerRequest.class); + final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); + when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test"))); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeaders()).thenReturn(null); + when(request.getUriInfo()).thenReturn(extendedUriInfo); + when(request.getHeaderString(HttpHeaders.USER_AGENT)).thenReturn(USER_AGENT); + + new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class)); + + verify(request).setProperty( + eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME), + eq(new MetricsHttpEventHandler.RequestInfo("/test", "GET", USER_AGENT))); + } + + @Test + void testResponseFilterModifiesRequestInfo() { + final MetricsHttpEventHandler.RequestInfo requestInfo = + new MetricsHttpEventHandler.RequestInfo("unknown", "POST", USER_AGENT); + + final ContainerRequest request = mock(ContainerRequest.class); + when(request.getProperty(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)).thenReturn(requestInfo); + final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); + when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test"))); + when(request.getUriInfo()).thenReturn(extendedUriInfo); + new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class)); + + assertEquals(new MetricsHttpEventHandler.RequestInfo("/test", "POST", USER_AGENT), requestInfo); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index c07e1e2a2..28b92584c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -34,10 +34,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerResponse; @@ -160,11 +159,9 @@ class MetricsRequestEventListenerTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); - final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); final UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)"); when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of("Signal-Android/4.53.7 (Android 8.1)"))); @@ -176,15 +173,15 @@ class MetricsRequestEventListenerTest { when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); final ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -228,11 +225,9 @@ class MetricsRequestEventListenerTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); - final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); final UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn( @@ -242,15 +237,15 @@ class MetricsRequestEventListenerTest { when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); final ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -297,11 +292,9 @@ class MetricsRequestEventListenerTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); - final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); final UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn( @@ -311,15 +304,15 @@ class MetricsRequestEventListenerTest { when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) .thenReturn(requestBytesCounter); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); final ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java index 3b3c53adf..41c6c088f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/TlsCertificateExpirationUtilTest.java @@ -149,15 +149,13 @@ class TlsCertificateExpirationUtilTest { @Test void test() throws Exception { - try (Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64)) { + final Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64); + final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD); - final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD); - - final Map expected = Map.of( - "localhost:EdDSA", EDDSA_EXPIRATION, - "localhost:RSA", RSA_EXPIRATION); - assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD)); - } + final Map expected = Map.of( + "localhost:EdDSA", EDDSA_EXPIRATION, + "localhost:RSA", RSA_EXPIRATION); + assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java index 4c1cdf3f9..44d904646 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java @@ -4,14 +4,6 @@ */ package org.whispersystems.textsecuregcm.tests.util; -import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WebSocketListener; -import org.whispersystems.websocket.messages.WebSocketMessage; -import org.whispersystems.websocket.messages.WebSocketMessageFactory; -import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; - -import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.Objects; @@ -19,8 +11,14 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import org.eclipse.jetty.websocket.api.Callback; +import org.eclipse.jetty.websocket.api.Session; +import org.whispersystems.websocket.messages.WebSocketMessage; +import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; -public class TestWebsocketListener implements WebSocketListener { +public class TestWebsocketListener implements Session.Listener.AutoDemanding { private final AtomicLong requestId = new AtomicLong(); private final CompletableFuture started = new CompletableFuture<>(); @@ -34,7 +32,7 @@ public class TestWebsocketListener implements WebSocketListener { @Override - public void onWebSocketConnect(final Session session) { + public void onWebSocketOpen(final Session session) { started.complete(session); } @@ -63,19 +61,15 @@ public class TestWebsocketListener implements WebSocketListener { responseFutures.put(id, future); final byte[] requestBytes = messageFactory.createRequest( Optional.of(id), verb, requestPath, headers, body).toByteArray(); - try { - session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes)); - } catch (IOException e) { - throw new RuntimeException(e); - } + session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP); return future; }); } @Override - public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { responseFutures.get(webSocketMessage.getResponseMessage().getRequestId()) .complete(webSocketMessage.getResponseMessage()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java index 831c4e526..e02282a65 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/jetty/TestResource.java @@ -6,12 +6,9 @@ package org.whispersystems.textsecuregcm.util.jetty; import java.io.ByteArrayInputStream; -import java.io.File; -import java.io.IOException; import java.io.InputStream; -import java.net.MalformedURLException; import java.net.URI; -import java.nio.channels.ReadableByteChannel; +import java.nio.file.Path; import java.util.Base64; import org.eclipse.jetty.util.resource.Resource; @@ -30,15 +27,16 @@ public class TestResource extends Resource { } @Override - public boolean isContainedIn(final Resource r) throws MalformedURLException { - return false; + public Path getPath() { + return null; } @Override - public void close() { - + public InputStream newInputStream() { + return new ByteArrayInputStream(data); } + @Override public boolean exists() { return true; @@ -50,13 +48,13 @@ public class TestResource extends Resource { } @Override - public long lastModified() { - return 0; + public boolean isReadable() { + return true; } @Override public long length() { - return 0; + return data.length; } @Override @@ -64,43 +62,19 @@ public class TestResource extends Resource { return null; } - @Override - public File getFile() throws IOException { - return null; - } - @Override public String getName() { return name; } @Override - public InputStream getInputStream() throws IOException { - return new ByteArrayInputStream(data); + public String getFileName() { + return ""; } @Override - public ReadableByteChannel getReadableByteChannel() throws IOException { + public Resource resolve(final String subUriPath) { return null; } - @Override - public boolean delete() throws SecurityException { - return false; - } - - @Override - public boolean renameTo(final Resource dest) throws SecurityException { - return false; - } - - @Override - public String[] list() { - return new String[]{name}; - } - - @Override - public Resource addPath(final String path) throws IOException, MalformedURLException { - return this; - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java index c91d91908..32413433d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java @@ -39,10 +39,9 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.stream.Stream; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.ServerProperties; @@ -143,12 +142,12 @@ class LoggingUnhandledExceptionMapperTest { WebSocketResourceProvider provider = createWebsocketProvider(userAgentHeader, session, responseFuture::complete); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory() .createRequest(Optional.of(111L), "GET", targetPath, new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); responseFuture.get(1, TimeUnit.SECONDS); @@ -179,15 +178,13 @@ class LoggingUnhandledExceptionMapperTest { TestPrincipal.authenticatedTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); doAnswer(answer -> { responseHandler.accept(answer.getArgument(0, ByteBuffer.class)); return null; - }).when(remoteEndpoint).sendBytes(any(), any(WriteCallback.class)); + }).when(session).sendBinary(any(ByteBuffer.class), any(Callback.class)); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn(userAgentHeader); when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of(userAgentHeader))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java index 4ecd722f3..55e7e7cc3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java @@ -19,7 +19,7 @@ import java.util.Optional; import java.util.UUID; import java.util.stream.Stream; import javax.annotation.Nullable; -import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -44,7 +44,7 @@ class WebSocketAccountAuthenticatorTest { private AccountAuthenticator accountAuthenticator; - private UpgradeRequest upgradeRequest; + private JettyServerUpgradeRequest upgradeRequest; @BeforeEach void setUp() { @@ -56,7 +56,7 @@ class WebSocketAccountAuthenticatorTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) .thenReturn(Optional.empty()); - upgradeRequest = mock(UpgradeRequest.class); + upgradeRequest = mock(JettyServerUpgradeRequest.class); } @ParameterizedTest diff --git a/websocket-resources/pom.xml b/websocket-resources/pom.xml index 156ff3dd7..8ed3c7d4f 100644 --- a/websocket-resources/pom.xml +++ b/websocket-resources/pom.xml @@ -13,15 +13,19 @@ org.eclipse.jetty.websocket - websocket-jetty-api + jetty-websocket-jetty-api org.eclipse.jetty.websocket - websocket-jetty-server + jetty-websocket-jetty-server - org.eclipse.jetty.websocket - websocket-servlet + org.eclipse.jetty.ee10.websocket + jetty-ee10-websocket-jetty-server + + + org.eclipse.jetty.ee10.websocket + jetty-ee10-websocket-servlet io.dropwizard diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java index 4f6d0d6a4..f2c900447 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java @@ -12,9 +12,8 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.eclipse.jetty.websocket.api.exceptions.WebSocketException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,15 +28,13 @@ public class WebSocketClient { private static final SecureRandom SECURE_RANDOM = new SecureRandom(); private final Session session; - private final RemoteEndpoint remoteEndpoint; private final WebSocketMessageFactory messageFactory; private final Map> pendingRequestMapper; private final Instant created; - public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, WebSocketMessageFactory messageFactory, + public WebSocketClient(Session session, WebSocketMessageFactory messageFactory, Map> pendingRequestMapper) { this.session = session; - this.remoteEndpoint = remoteEndpoint; this.messageFactory = messageFactory; this.pendingRequestMapper = pendingRequestMapper; this.created = Instant.now(); @@ -55,9 +52,9 @@ public class WebSocketClient { WebSocketMessage requestMessage = messageFactory.createRequest(Optional.of(requestId), verb, path, headers, body); try { - remoteEndpoint.sendBytes(ByteBuffer.wrap(requestMessage.toByteArray()), new WriteCallback() { + session.sendBinary(ByteBuffer.wrap(requestMessage.toByteArray()), new Callback() { @Override - public void writeFailed(Throwable x) { + public void fail(Throwable x) { logger.debug("Write failed", x); pendingRequestMapper.remove(requestId); future.completeExceptionally(x); @@ -85,9 +82,9 @@ public class WebSocketClient { } public void close(final int code, final String message) { - session.close(code, message, new WriteCallback() { + session.close(code, message, new Callback() { @Override - public void writeFailed(final Throwable throwable) { + public void fail(final Throwable throwable) { try { session.disconnect(); } catch (final Exception e) { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 0cd80f02f..b3517ab63 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -24,10 +24,8 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.WebSocketListener; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.eclipse.jetty.websocket.api.exceptions.MessageTooLargeException; import org.glassfish.jersey.internal.MapPropertiesDelegate; import org.glassfish.jersey.server.ApplicationHandler; @@ -47,7 +45,7 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") -public class WebSocketResourceProvider implements WebSocketListener { +public class WebSocketResourceProvider implements Session.Listener.AutoDemanding { /** * A static exception instance passed to outstanding requests (via {@code completeExceptionally} in @@ -68,7 +66,6 @@ public class WebSocketResourceProvider implements WebSocket private final String remoteAddressPropertyName; private Session session; - private RemoteEndpoint remoteEndpoint; private WebSocketSessionContext context; private static final Set EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade"); @@ -92,11 +89,10 @@ public class WebSocketResourceProvider implements WebSocket } @Override - public void onWebSocketConnect(Session session) { + public void onWebSocketOpen(Session session) { this.session = session; - this.remoteEndpoint = session.getRemote(); this.context = new WebSocketSessionContext( - new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap)); + new WebSocketClient(session, messageFactory, requestMap)); this.context.setAuthenticated(reusableAuth.orElse(null)); this.session.setIdleTimeout(idleTimeout); @@ -121,9 +117,9 @@ public class WebSocketResourceProvider implements WebSocket } @Override - public void onWebSocketBinary(byte[] payload, int offset, int length) { + public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { try { - WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); switch (webSocketMessage.getType()) { case REQUEST_MESSAGE: @@ -258,7 +254,7 @@ public class WebSocketResourceProvider implements WebSocket } private void close(Session session, int status, String message) { - session.close(status, message); + session.close(status, message, Callback.NOOP); } private void sendResponse(WebSocketRequestMessage requestMessage, ContainerResponse response, @@ -277,7 +273,7 @@ public class WebSocketResourceProvider implements WebSocket Optional.ofNullable(body)) .toByteArray(); - remoteEndpoint.sendBytes(ByteBuffer.wrap(responseBytes), WriteCallback.NOOP); + session.sendBinary(ByteBuffer.wrap(responseBytes), Callback.NOOP); } } @@ -289,7 +285,7 @@ public class WebSocketResourceProvider implements WebSocket getHeaderList(error.getStringHeaders()), Optional.empty()); - remoteEndpoint.sendBytes(ByteBuffer.wrap(response.toByteArray()), WriteCallback.NOOP); + session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP); } } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 4532794d4..c18f3b1e2 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -13,11 +13,9 @@ import java.security.Principal; import java.util.Map; import java.util.Optional; import org.apache.commons.lang3.StringUtils; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; -import org.eclipse.jetty.websocket.server.JettyWebSocketCreator; -import org.eclipse.jetty.websocket.server.JettyWebSocketServlet; -import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketCreator; import org.glassfish.jersey.CommonProperties; import org.glassfish.jersey.server.ApplicationHandler; import org.slf4j.Logger; @@ -25,23 +23,20 @@ import org.slf4j.LoggerFactory; import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; -import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.setup.WebSocketEnvironment; -public class WebSocketResourceProviderFactory extends JettyWebSocketServlet implements - JettyWebSocketCreator { +public class WebSocketResourceProviderFactory implements JettyWebSocketCreator { private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class); private final WebSocketEnvironment environment; private final ApplicationHandler jerseyApplicationHandler; - private final WebSocketConfiguration configuration; private final String remoteAddressPropertyName; public WebSocketResourceProviderFactory(WebSocketEnvironment environment, Class principalClass, - WebSocketConfiguration configuration, String remoteAddressPropertyName) { + String remoteAddressPropertyName) { this.environment = environment; environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); @@ -55,7 +50,6 @@ public class WebSocketResourceProviderFactory extends Jetty this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); - this.configuration = configuration; this.remoteAddressPropertyName = remoteAddressPropertyName; } @@ -89,6 +83,7 @@ public class WebSocketResourceProviderFactory extends Jetty // Authentication may fail for non-incorrect-credential reasons (e.g. we couldn't read from the account database). // If that happens, we don't want to incorrectly tell clients that they provided bad credentials. logger.warn("Authentication failure", e); + try { response.sendError(500, "Failure"); } catch (final IOException ignored) { @@ -97,13 +92,6 @@ public class WebSocketResourceProviderFactory extends Jetty } } - @Override - public void configure(JettyWebSocketServletFactory factory) { - factory.setCreator(this); - factory.setMaxBinaryMessageSize(configuration.getMaxBinaryMessageSize()); - factory.setMaxTextMessageSize(configuration.getMaxTextMessageSize()); - } - private String getRemoteAddress(JettyServerUpgradeRequest request) { final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName); if (StringUtils.isBlank(remoteAddress)) { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java index bf0593dc0..f2f49cc9f 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java @@ -7,8 +7,8 @@ package org.whispersystems.websocket.auth; import java.security.Principal; import java.util.Optional; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; public interface AuthenticatedWebSocketUpgradeFilter { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java index 7836b5b8d..4974a707e 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java @@ -6,7 +6,7 @@ package org.whispersystems.websocket.auth; import java.security.Principal; import java.util.Optional; -import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; public interface WebSocketAuthenticator { @@ -20,5 +20,5 @@ public interface WebSocketAuthenticator { * * @throws InvalidCredentialsException if credentials were provided, but could not be authenticated */ - Optional authenticate(UpgradeRequest request) throws InvalidCredentialsException; + Optional authenticate(JettyServerUpgradeRequest request) throws InvalidCredentialsException; } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java index d24a7e09b..bbaeb75cc 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/WebSocketMessageFactory.java @@ -5,22 +5,23 @@ package org.whispersystems.websocket.messages; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public interface WebSocketMessageFactory { - public WebSocketMessage parseMessage(byte[] serialized, int offset, int len) + WebSocketMessage parseMessage(ByteBuffer serialized) throws InvalidMessageException; - public WebSocketMessage createRequest(Optional requestId, - String verb, String path, - List headers, - Optional body); + WebSocketMessage createRequest(Optional requestId, + String verb, String path, + List headers, + Optional body); - public WebSocketMessage createResponse(long requestId, int status, String message, - List headers, - Optional body); + WebSocketMessage createResponse(long requestId, int status, String message, + List headers, + Optional body); } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java index 909673363..8d021cc81 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessage.java @@ -10,14 +10,15 @@ import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketRequestMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import java.nio.ByteBuffer; public class ProtobufWebSocketMessage implements WebSocketMessage { private final SubProtocol.WebSocketMessage message; - ProtobufWebSocketMessage(byte[] buffer, int offset, int length) throws InvalidMessageException { + ProtobufWebSocketMessage(ByteBuffer buffer) throws InvalidMessageException { try { - this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer, offset, length)); + this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer)); if (getType() == Type.REQUEST_MESSAGE) { if (!message.getRequest().hasVerb() || !message.getRequest().hasPath()) { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java index d4f25bd89..33c108771 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/messages/protobuf/ProtobufWebSocketMessageFactory.java @@ -9,16 +9,17 @@ import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; public class ProtobufWebSocketMessageFactory implements WebSocketMessageFactory { @Override - public WebSocketMessage parseMessage(byte[] serialized, int offset, int len) + public WebSocketMessage parseMessage(ByteBuffer serialized) throws InvalidMessageException { - return new ProtobufWebSocketMessage(serialized, offset, len); + return new ProtobufWebSocketMessage(serialized); } @Override diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index 31e44e09f..4cc5e98f4 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -19,17 +19,15 @@ import java.io.IOException; import java.security.Principal; import java.util.Optional; import javax.security.auth.Subject; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; -import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; import org.glassfish.jersey.server.ResourceConfig; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; +import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; -import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.setup.WebSocketEnvironment; public class WebSocketResourceProviderFactoryTest { @@ -60,8 +58,7 @@ public class WebSocketResourceProviderFactoryTest { when(authenticator.authenticate(eq(request))).thenThrow(new InvalidCredentialsException()); when(environment.jersey()).thenReturn(jerseyEnvironment); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); assertNull(connection); @@ -80,16 +77,15 @@ public class WebSocketResourceProviderFactoryTest { final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); when(request.getHttpServletRequest()).thenReturn(httpServletRequest); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); + REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); assertNotNull(connection); verifyNoMoreInteractions(response); verify(authenticator).authenticate(eq(request)); - ((WebSocketResourceProvider) connection).onWebSocketConnect(mock(Session.class)); + ((WebSocketResourceProvider) connection).onWebSocketOpen(mock(Session.class)); assertNotNull(((WebSocketResourceProvider) connection).getContext().getAuthenticated()); assertEquals(((WebSocketResourceProvider) connection).getContext().getAuthenticated(), account); @@ -103,7 +99,6 @@ public class WebSocketResourceProviderFactoryTest { WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); @@ -112,20 +107,6 @@ public class WebSocketResourceProviderFactoryTest { verify(authenticator).authenticate(eq(request)); } - @Test - void testConfigure() { - JettyWebSocketServletFactory servletFactory = mock(JettyWebSocketServletFactory.class); - when(environment.jersey()).thenReturn(jerseyEnvironment); - - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, - Account.class, - mock(WebSocketConfiguration.class), - REMOTE_ADDRESS_PROPERTY_NAME); - factory.configure(servletFactory); - - verify(servletFactory).setCreator(eq(factory)); - } - @Test void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException { final Account account = new Account(); @@ -137,12 +118,11 @@ public class WebSocketResourceProviderFactoryTest { final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); when(request.getHttpServletRequest()).thenReturn(httpServletRequest); - final AuthenticatedWebSocketUpgradeFilter filter = mock(AuthenticatedWebSocketUpgradeFilter.class); when(environment.getAuthenticatedWebSocketUpgradeFilter()).thenReturn(filter); final WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); + REMOTE_ADDRESS_PROPERTY_NAME); assertNotNull(factory.createWebSocket(request, response)); verify(filter).handleAuthentication(reusableAuth, request, response); diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index 953190bac..3e281dc40 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -47,11 +47,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import org.eclipse.jetty.websocket.api.CloseStatus; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; +import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WriteCallback; import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerResponse; @@ -90,11 +88,10 @@ class WebSocketResourceProviderTest { when(session.getUpgradeRequest()).thenReturn(request); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); - verify(session, never()).close(anyInt(), anyString()); + verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(); - verify(session, never()).close(any(CloseStatus.class)); ArgumentCaptor contextArgumentCaptor = ArgumentCaptor.forClass( WebSocketSessionContext.class); @@ -112,11 +109,9 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); ContainerResponse response = mock(ContainerResponse.class); when(response.getStatus()).thenReturn(200); @@ -146,16 +141,15 @@ class WebSocketResourceProviderTest { return CompletableFuture.completedFuture(response); }); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); - verify(session, never()).close(anyInt(), anyString()); + verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(); - verify(session, never()).close(any(CloseStatus.class)); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class); @@ -169,7 +163,7 @@ class WebSocketResourceProviderTest { assertThat(bundledRequest.getPath(false)).isEqualTo("bar"); verify(requestLog).log(eq("127.0.0.1"), eq(bundledRequest), eq(response)); - verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom( responseCaptor.getValue().array()); @@ -189,25 +183,22 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); when(applicationHandler.apply(any(ContainerRequest.class), any(OutputStream.class))).thenReturn( CompletableFuture.failedFuture(new IllegalStateException("foo"))); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); - verify(session, never()).close(anyInt(), anyString()); + verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(); - verify(session, never()).close(any(CloseStatus.class)); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); @@ -221,7 +212,7 @@ class WebSocketResourceProviderTest { ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom( responseCaptor.getValue().array()); @@ -245,22 +236,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -285,22 +274,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/doesntexist", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -325,22 +312,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -365,22 +350,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -404,22 +387,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -444,22 +425,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -484,23 +463,21 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 1001)))).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -525,23 +502,21 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 5)))).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -567,22 +542,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/exception/map", List.of("Content-Type: application/json"), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); @@ -607,22 +580,20 @@ class WebSocketResourceProviderTest { new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(session.getRemote()).thenReturn(remoteEndpoint); - provider.onWebSocketConnect(session); + provider.onWebSocketOpen(session); byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/keepalive", new LinkedList<>(), Optional.empty()).toByteArray(); - provider.onWebSocketBinary(message, 0, message.length); + provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(requestCaptor.capture(), any(WriteCallback.class)); + verify(session).sendBinary(requestCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketRequestMessage requestMessage = getRequest(requestCaptor); assertThat(requestMessage.getVerb()).isEqualTo("GET"); @@ -632,11 +603,11 @@ class WebSocketResourceProviderTest { byte[] clientResponse = new ProtobufWebSocketMessageFactory().createResponse(requestMessage.getId(), 200, "OK", new LinkedList<>(), Optional.of("my response".getBytes())).toByteArray(); - provider.onWebSocketBinary(clientResponse, 0, clientResponse.length); + provider.onWebSocketBinary(ByteBuffer.wrap(clientResponse), Callback.NOOP); ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint, times(2)).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + verify(session, times(2)).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);