diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java index 6714510b0..1cd3b6d45 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java @@ -12,6 +12,10 @@ import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import javax.annotation.Nullable; import org.glassfish.jersey.server.ContainerResponse; import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEventListener; @@ -21,11 +25,6 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil; import org.whispersystems.websocket.WebSocketResourceProvider; -import javax.annotation.Nullable; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; - /** * Gathers and reports request-level metrics for WebSocket traffic only. * For HTTP traffic, use {@link MetricsHttpChannelListener}. @@ -53,6 +52,9 @@ public class MetricsRequestEventListener implements RequestEventListener { @VisibleForTesting static final String TRAFFIC_SOURCE_TAG = "trafficSource"; + @VisibleForTesting + static final String AUTHENTICATED_TAG = "authenticated"; + private final TrafficSource trafficSource; private final MeterRegistry meterRegistry; @@ -86,6 +88,12 @@ public class MetricsRequestEventListener implements RequestEventListener { .map(ContainerResponse::getStatus) .orElse(499)))); tags.add(Tag.of(TRAFFIC_SOURCE_TAG, trafficSource.name().toLowerCase())); + tags.add(Tag.of(AUTHENTICATED_TAG, Optional.ofNullable(event.getContainerRequest().getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY)) + .filter(Optional.class::isInstance) + .map(Optional.class::cast) + .map(Optional::isPresent) + .orElse(false) + .toString())); @Nullable final String userAgent; { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index 4844c1583..c07e1e2a2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -47,6 +47,8 @@ import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.uri.UriTemplate; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; @@ -129,11 +131,12 @@ class MetricsRequestEventListenerTest { tags.add(tag); } - assertEquals(6, tags.size()); + assertEquals(7, tags.size()); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, path))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.METHOD_TAG, method))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(statusCode)))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.AUTHENTICATED_TAG, "false"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "true"))); } @@ -196,11 +199,12 @@ class MetricsRequestEventListenerTest { tags.add(tag); } - assertEquals(6, tags.size()); + assertEquals(7, tags.size()); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, "/v1/test/hello"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.METHOD_TAG, "GET"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(200)))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.AUTHENTICATED_TAG, "true"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); } @@ -261,11 +265,81 @@ class MetricsRequestEventListenerTest { tags.add(tag); } - assertEquals(6, tags.size()); + assertEquals(7, tags.size()); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, "/v1/test/hello"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.METHOD_TAG, "GET"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(200)))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.AUTHENTICATED_TAG, "true"))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); + } + + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testAuthenticated(final boolean authenticated) throws IOException { + final MetricsApplicationEventListener applicationEventListener = mock(MetricsApplicationEventListener.class); + when(applicationEventListener.onRequest(any())).thenReturn(listener); + + final ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(applicationEventListener); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + final Optional maybePrincipal = authenticated ? TestPrincipal.authenticatedTestPrincipal("foo") : Optional.empty(); + final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, maybePrincipal, + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); + + final Session session = mock(Session.class); + final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + final UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn( + counter); + when(meterRegistry.counter(eq(MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class))) + .thenReturn(responseBytesCounter); + when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) + .thenReturn(requestBytesCounter); + + provider.onWebSocketConnect(session); + + final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", + new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + final ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getStatus()).isEqualTo(200); + + verify(meterRegistry).counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + + final Iterable tagIterable = tagCaptor.getValue(); + final Set tags = new HashSet<>(); + + for (final Tag tag : tagIterable) { + tags.add(tag); + } + + assertEquals(7, tags.size()); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, "/v1/test/hello"))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.METHOD_TAG, "GET"))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(200)))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.AUTHENTICATED_TAG, String.valueOf(authenticated)))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); }