Add static servlet paths to MetricsHttpChannelListener

This commit is contained in:
Chris Eager
2024-02-14 15:27:13 -06:00
committed by Chris Eager
parent f90ccd3391
commit 9ce2b7555c
4 changed files with 169 additions and 22 deletions

View File

@@ -27,8 +27,10 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.io.IOException;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
@@ -38,6 +40,7 @@ import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.annotation.Priority;
import javax.security.auth.Subject;
import javax.servlet.DispatcherType;
import javax.ws.rs.GET;
import javax.ws.rs.InternalServerErrorException;
import javax.ws.rs.NotAuthorizedException;
@@ -55,13 +58,21 @@ import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
@@ -148,6 +159,64 @@ class MetricsHttpChannelListenerIntegrationTest {
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
}
@Nested
class WebSocket {
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
@Test
void testWebSocketUpgrade() throws Exception {
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(HttpHeaders.USER_AGENT, "Signal-Android/4.53.7 (Android 8.1)");
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> MetricsHttpChannelListener.REQUEST_COUNTER_NAME.equals(a.getArgument(0, String.class))
? COUNTER
: mock(Counter.class))
.thenReturn(COUNTER);
client.connect(new WebSocketListener() {
@Override
public void onWebSocketConnect(final Session session) {
session.close(1000, "OK");
}
},
URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest)
.get(1, TimeUnit.SECONDS);
verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue();
final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagIterable) {
tags.add(tag);
}
assertEquals(5, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101))));
assertTrue(
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
}
}
static Stream<Arguments> testSimplePath() {
return Stream.of(
Arguments.of("/v1/test/hello", "/v1/test/hello", "Hello!", 200),
@@ -166,11 +235,16 @@ class MetricsHttpChannelListenerIntegrationTest {
final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(
METER_REGISTRY,
mock(ClientReleaseManager.class));
mock(ClientReleaseManager.class),
Set.of("/v1/websocket")
);
metricsHttpChannelListener.configure(environment);
environment.lifecycle().addEventListener(new TestListener(LISTENER_FUTURE_REFERENCE));
environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.jersey().register(new TestResource());
environment.jersey().register(new TestAuthFilter());
@@ -185,9 +259,11 @@ class MetricsHttpChannelListenerIntegrationTest {
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, "ignored");
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
environment.servlets().addServlet("WebSocket", webSocketServlet);
environment.servlets().addServlet("WebSocket", webSocketServlet)
.addMapping("/v1/websocket");
}
}
@@ -273,4 +349,5 @@ class MetricsHttpChannelListenerIntegrationTest {
return false;
}
}
}

View File

@@ -11,12 +11,14 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -27,29 +29,39 @@ import org.glassfish.jersey.server.ExtendedUriInfo;
import org.glassfish.jersey.uri.UriTemplate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
class MetricsHttpChannelListenerTest {
private MeterRegistry meterRegistry;
private Counter counter;
private Counter requestCounter;
private Counter requestsByVersionCounter;
private ClientReleaseManager clientReleaseManager;
private MetricsHttpChannelListener listener;
@BeforeEach
void setup() {
meterRegistry = mock(MeterRegistry.class);
counter = mock(Counter.class);
requestCounter = mock(Counter.class);
requestsByVersionCounter = mock(Counter.class);
final ClientReleaseManager clientReleaseManager = mock(ClientReleaseManager.class);
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(false);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestCounter);
listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestsByVersionCounter);
clientReleaseManager = mock(ClientReleaseManager.class);
listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet());
}
@Test
@SuppressWarnings("unchecked")
void testOnEvent() {
void testRequests() {
final String path = "/test";
final String method = "GET";
final int statusCode = 200;
@@ -70,17 +82,15 @@ class MetricsHttpChannelListenerTest {
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(counter);
listener.onComplete(request);
verify(requestCounter).increment();
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
final Iterable<Tag> tagIterable = tagCaptor.getValue();
final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagIterable) {
for (final Tag tag : tagCaptor.getValue()) {
tags.add(tag);
}
@@ -92,4 +102,50 @@ class MetricsHttpChannelListenerTest {
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
@SuppressWarnings("unchecked")
void testRequestsByVersion(final boolean versionActive) {
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive);
final String path = "/test";
final String method = "GET";
final int statusCode = 200;
final HttpURI httpUri = mock(HttpURI.class);
when(httpUri.getPath()).thenReturn(path);
final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
when(request.getHttpURI()).thenReturn(httpUri);
final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode);
when(request.getResponse()).thenReturn(response);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
listener.onComplete(request);
if (versionActive) {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME),
tagCaptor.capture());
final Set<Tag> tags = new HashSet<>();
tags.clear();
for (final Tag tag : tagCaptor.getValue()) {
tags.add(tag);
}
assertEquals(2, tags.size());
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "6.53.7")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
} else {
verifyNoInteractions(requestsByVersionCounter);
}
}
}