diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 52a87fe98..4717aa6d3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -1148,7 +1148,7 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000)); - provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90))); + provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, clientReleaseManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90))); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager)); provisioningEnvironment.jersey().register(new KeepAliveController(redisMessageAvailabilityManager)); provisioningEnvironment.jersey().register(new TimestampResponseFilter()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java index a545ed788..ab7e4e3e8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java @@ -3,95 +3,100 @@ package org.whispersystems.textsecuregcm.metrics; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; -import org.whispersystems.textsecuregcm.util.EnumMapUtil; -import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; +import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.websocket.session.WebSocketSessionContext; public class OpenWebSocketCounter { - private static final String WEBSOCKET_CLOSED_COUNTER_NAME = name(OpenWebSocketCounter.class, "websocketClosed"); + private final ClientReleaseManager clientReleaseManager; - private final String newConnectionCounterName; - private final String durationTimerName; + private final Tags baseTags; - private final Tags tags; + private final Map openWebsocketsByTags; + private final AtomicInteger totalConnections; - private final Map openWebsocketsByClientPlatform; - private final AtomicInteger openWebsocketsFromUnknownPlatforms; + private static final int MAX_COUNTERS = 4096; - public OpenWebSocketCounter(final String openWebSocketGaugeName, - final String newConnectionCounterName, - final String durationTimerName) { + private static final String OPEN_WEBSOCKET_GAUGE_NAME = name(OpenWebSocketCounter.class, "openWebsockets"); + private static final String TOTAL_CONNECTIONS_GAUGE_NAME = name(OpenWebSocketCounter.class, "totalOpenWebsockets"); + private static final String NEW_CONNECTION_COUNTER_NAME = name(OpenWebSocketCounter.class, "newConnections"); + private static final String WEB_SOCKET_CLOSED_COUNTER_NAME = name(OpenWebSocketCounter.class, "websocketClosed"); + private static final String SESSION_DURATION_TIMER_NAME = name(OpenWebSocketCounter.class, "sessionDuration"); + private static final String GAUGE_COUNT_GAUGE_NAME = name(OpenWebSocketCounter.class, "gaugeCount"); - this(openWebSocketGaugeName, newConnectionCounterName, durationTimerName, Tags.empty()); - } + public OpenWebSocketCounter(final String webSocketType, + final ClientReleaseManager clientReleaseManager) { - public OpenWebSocketCounter(final String openWebSocketGaugeName, - final String newConnectionCounterName, - final String durationTimerName, - final Tags tags) { + this.clientReleaseManager = clientReleaseManager; - this.newConnectionCounterName = newConnectionCounterName; - this.durationTimerName = durationTimerName; + this.baseTags = Tags.of("webSocketType", webSocketType); + this.openWebsocketsByTags = Metrics.gaugeMapSize(GAUGE_COUNT_GAUGE_NAME, baseTags, new ConcurrentHashMap<>()); - this.tags = tags; - - openWebsocketsByClientPlatform = EnumMapUtil.toEnumMap(ClientPlatform.class, - clientPlatform -> buildGauge(openWebSocketGaugeName, clientPlatform.name().toLowerCase(), tags)); - - openWebsocketsFromUnknownPlatforms = buildGauge(openWebSocketGaugeName, "unknown", tags); - } - - private static AtomicInteger buildGauge(final String gaugeName, final String clientPlatformName, final Tags tags) { - return Metrics.gauge(gaugeName, - tags.and(Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)), - new AtomicInteger(0)); + this.totalConnections = Metrics.gauge(TOTAL_CONNECTIONS_GAUGE_NAME, baseTags, new AtomicInteger(0)); } public void countOpenWebSocket(final WebSocketSessionContext context) { final Timer.Sample sample = Timer.start(); - // We have to jump through some hoops here to have something "effectively final" for the close listener, but - // assignable from a `catch` block. - final AtomicInteger openWebSocketCounter; - + @Nullable final UserAgent userAgent; { - AtomicInteger calculatedOpenWebSocketCounter; + UserAgent parsedUserAgent; try { - final ClientPlatform clientPlatform = - UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).platform(); - - calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform); + parsedUserAgent = UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()); } catch (final UnrecognizedUserAgentException e) { - calculatedOpenWebSocketCounter = openWebsocketsFromUnknownPlatforms; + parsedUserAgent = null; } - openWebSocketCounter = calculatedOpenWebSocketCounter; + userAgent = parsedUserAgent; } - openWebSocketCounter.incrementAndGet(); + final Tags tagsWithClientPlatform = baseTags.and(UserAgentTagUtil.getPlatformTag(userAgent)); - final Tags tagsWithClientPlatform = tags.and(UserAgentTagUtil.getPlatformTag(context.getClient().getUserAgent())); + final Optional maybeOpenWebSocketCounter; + { + final Tags tagsWithAdditionalSpecifiers = tagsWithClientPlatform + .and(UserAgentTagUtil.getClientVersionTag(userAgent, clientReleaseManager) + .map(Tags::of) + .orElseGet(Tags::empty)) + .and(UserAgentTagUtil.getAdditionalSpecifierTags(userAgent)); - Metrics.counter(newConnectionCounterName, tagsWithClientPlatform).increment(); + maybeOpenWebSocketCounter = getCounter(tagsWithAdditionalSpecifiers); + } + + maybeOpenWebSocketCounter.ifPresent(AtomicInteger::incrementAndGet); + totalConnections.incrementAndGet(); + + Metrics.counter(NEW_CONNECTION_COUNTER_NAME, tagsWithClientPlatform).increment(); context.addWebsocketClosedListener((_, statusCode, _) -> { - sample.stop(Timer.builder(durationTimerName) + sample.stop(Timer.builder(SESSION_DURATION_TIMER_NAME) .tags(tagsWithClientPlatform) .register(Metrics.globalRegistry)); - openWebSocketCounter.decrementAndGet(); + maybeOpenWebSocketCounter.ifPresent(AtomicInteger::decrementAndGet); + totalConnections.decrementAndGet(); - Metrics.counter(WEBSOCKET_CLOSED_COUNTER_NAME, tagsWithClientPlatform.and("status", String.valueOf(statusCode))) + Metrics.counter(WEB_SOCKET_CLOSED_COUNTER_NAME, tagsWithClientPlatform.and("status", String.valueOf(statusCode))) .increment(); }); } + + private Optional getCounter(final Tags tags) { + // Make a reasonable effort to avoid creating new counters if we're already full + return openWebsocketsByTags.size() >= MAX_COUNTERS + ? Optional.ofNullable(openWebsocketsByTags.get(tags)) + : Optional.of(openWebsocketsByTags.computeIfAbsent(tags, + t -> Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, t, new AtomicInteger(0)))); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 2cd6239db..9f3bca729 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -5,12 +5,8 @@ package org.whispersystems.textsecuregcm.websocket; -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; - import com.google.common.annotations.VisibleForTesting; -import io.micrometer.core.instrument.Tags; import java.util.Optional; -import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; @@ -35,13 +31,6 @@ import reactor.core.scheduler.Scheduler; public class AuthenticatedConnectListener implements WebSocketConnectListener { - private static final String OPEN_WEBSOCKET_GAUGE_NAME = name(AuthenticatedConnectListener.class, "openWebsockets"); - private static final String NEW_CONNECTION_COUNTER_NAME = name(AuthenticatedConnectListener.class, "newConnections"); - private static final String CONNECTED_DURATION_TIMER_NAME = - name(AuthenticatedConnectListener.class, "connectedDuration"); - - private static final String AUTHENTICATED_TAG_NAME = "authenticated"; - private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class); private final AccountsManager accountsManager; @@ -72,6 +61,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this(accountsManager, disconnectionRequestManager, + clientReleaseManager, (account, device, client) -> new WebSocketConnection(receiptSender, messagesManager, messageMetrics, @@ -83,26 +73,22 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, - experimentEnrollmentManager), - authenticated -> new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, - NEW_CONNECTION_COUNTER_NAME, - CONNECTED_DURATION_TIMER_NAME, - Tags.of(AUTHENTICATED_TAG_NAME, String.valueOf(authenticated))) + experimentEnrollmentManager) ); } @VisibleForTesting AuthenticatedConnectListener( final AccountsManager accountsManager, final DisconnectionRequestManager disconnectionRequestManager, - final WebSocketConnectionBuilder webSocketConnectionBuilder, - final Function openWebSocketCounterBuilder) { + final ClientReleaseManager clientReleaseManager, + final WebSocketConnectionBuilder webSocketConnectionBuilder) { this.accountsManager = accountsManager; this.disconnectionRequestManager = disconnectionRequestManager; this.webSocketConnectionBuilder = webSocketConnectionBuilder; - openAuthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(true); - openUnauthenticatedWebSocketCounter = openWebSocketCounterBuilder.apply(false); + this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", clientReleaseManager); + this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", clientReleaseManager); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java index 8eb6287b0..a76675087 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java @@ -18,9 +18,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.controllers.ProvisioningController; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.push.ProvisioningManager; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.websocket.session.WebSocketSessionContext; @@ -48,14 +48,13 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { private final Duration timeout; public ProvisioningConnectListener(final ProvisioningManager provisioningManager, + final ClientReleaseManager clientReleaseManager, final ScheduledExecutorService timeoutExecutor, final Duration timeout) { this.provisioningManager = provisioningManager; this.timeoutExecutor = timeoutExecutor; this.timeout = timeout; - this.openWebSocketCounter = new OpenWebSocketCounter(MetricsUtil.name(getClass(), "openWebsockets"), - MetricsUtil.name(getClass(), "newConnections"), - MetricsUtil.name(getClass(), "sessionDuration")); + this.openWebSocketCounter = new OpenWebSocketCounter("provisioning", clientReleaseManager); } @Override @@ -67,7 +66,7 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { final String provisioningAddress = generateProvisioningAddress(); - context.addWebsocketClosedListener((context1, statusCode, reason) -> { + context.addWebsocketClosedListener((_, _, _) -> { provisioningManager.removeListener(provisioningAddress); timeoutFuture.cancel(false); }); @@ -78,7 +77,7 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { final Optional body = Optional.of(message.getContent().toByteArray()); context.getClient().sendRequest("PUT", "/v1/message", List.of(HeaderUtils.getTimestampHeader()), body) - .whenComplete((ignored, throwable) -> context.getClient().close(1000, "Closed")); + .whenComplete((_, _) -> context.getClient().close(1000, "Closed")); }); context.getClient().sendRequest("PUT", "/v1/address", List.of(HeaderUtils.getTimestampHeader()), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java index da79f811f..de9ab693b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java @@ -6,7 +6,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -42,10 +41,10 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.push.ProvisioningManager; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener; import org.whispersystems.websocket.WebSocketResourceProviderFactory; -import org.whispersystems.websocket.WebsocketHeaders; import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.WebSocketMessage; @@ -110,7 +109,7 @@ public class ProvisioningTimeoutIntegrationTest { .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.setConnectListener( - new ProvisioningConnectListener(mock(ProvisioningManager.class), scheduler, Duration.ofSeconds(5))); + new ProvisioningConnectListener(mock(ProvisioningManager.class), mock(ClientReleaseManager.class), scheduler, Duration.ofSeconds(5))); final WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java index 4a6608dc5..f6d0ff1ef 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java @@ -12,6 +12,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -23,9 +24,9 @@ import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.identity.IdentityType; -import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.session.WebSocketSessionContext; @@ -54,8 +55,8 @@ class AuthenticatedConnectListenerTest { authenticatedConnectListener = new AuthenticatedConnectListener(accountsManager, disconnectionRequestManager, - (_, _, _) -> authenticatedWebSocketConnection, - _ -> mock(OpenWebSocketCounter.class)); + mock(ClientReleaseManager.class), + (_, _, _) -> authenticatedWebSocketConnection); final Device device = mock(Device.class); when(device.getId()).thenReturn(DEVICE_ID); @@ -81,7 +82,8 @@ class AuthenticatedConnectListenerTest { authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); - verify(webSocketSessionContext).addWebsocketClosedListener(any()); + // We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter + verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any()); verify(authenticatedWebSocketConnection).start(); } @@ -98,7 +100,8 @@ class AuthenticatedConnectListenerTest { verify(webSocketClient).close(eq(1011), anyString()); verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any()); - verify(webSocketSessionContext, never()).addWebsocketClosedListener(any()); + // We expect one call from OpenWebSocketCounter, but none from AuthenticatedConnectListener itself + verify(webSocketSessionContext, times(1)).addWebsocketClosedListener(any()); verify(authenticatedWebSocketConnection, never()).start(); } @@ -114,7 +117,8 @@ class AuthenticatedConnectListenerTest { authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); verify(disconnectionRequestManager).addListener(ACCOUNT_IDENTIFIER, DEVICE_ID, authenticatedWebSocketConnection); - verify(webSocketSessionContext).addWebsocketClosedListener(any()); + // We expect one call from AuthenticatedConnectListener itself and one from OpenWebSocketCounter + verify(webSocketSessionContext, times(2)).addWebsocketClosedListener(any()); verify(authenticatedWebSocketConnection).start(); verify(webSocketClient).close(eq(1011), anyString()); @@ -125,7 +129,8 @@ class AuthenticatedConnectListenerTest { authenticatedConnectListener.onWebSocketConnect(webSocketSessionContext); verify(disconnectionRequestManager, never()).addListener(any(), anyByte(), any()); - verify(webSocketSessionContext, never()).addWebsocketClosedListener(any()); + // We expect one call from OpenWebSocketCounter, but none from AuthenticatedConnectListener itself + verify(webSocketSessionContext, times(1)).addWebsocketClosedListener(any()); verify(authenticatedWebSocketConnection, never()).start(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java index abe5eae6b..04021acf7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java @@ -11,7 +11,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import com.google.protobuf.InvalidProtocolBufferException; import java.time.Duration; @@ -24,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.push.ProvisioningManager; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.session.WebSocketSessionContext; @@ -40,7 +40,7 @@ class ProvisioningConnectListenerTest { provisioningManager = mock(ProvisioningManager.class); scheduledExecutorService = mock(ScheduledExecutorService.class); provisioningConnectListener = - new ProvisioningConnectListener(provisioningManager, scheduledExecutorService, TIMEOUT); + new ProvisioningConnectListener(provisioningManager, mock(ClientReleaseManager.class), scheduledExecutorService, TIMEOUT); } @Test