diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 44bd3f9c6..884c62642 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -1072,7 +1072,7 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000)); - provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, clientReleaseManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90))); + provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, asnInfoProviderSupplier, clientReleaseManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90))); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager)); provisioningEnvironment.jersey().register(new KeepAliveController(redisMessageAvailabilityManager)); provisioningEnvironment.jersey().register(new TimestampResponseFilter()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java index ab7e4e3e8..cf5f14690 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java @@ -2,14 +2,19 @@ package org.whispersystems.textsecuregcm.metrics; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; +import com.google.common.net.InetAddresses; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; +import java.net.InetSocketAddress; import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.asn.AsnInfo; +import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgent; @@ -18,11 +23,13 @@ import org.whispersystems.websocket.session.WebSocketSessionContext; public class OpenWebSocketCounter { + private final Supplier asnInfoProviderSupplier; private final ClientReleaseManager clientReleaseManager; private final Tags baseTags; private final Map openWebsocketsByTags; + private final Map openWebsocketsByAsnRegion = new ConcurrentHashMap<>(); private final AtomicInteger totalConnections; private static final int MAX_COUNTERS = 4096; @@ -33,10 +40,13 @@ public class OpenWebSocketCounter { private static final String WEB_SOCKET_CLOSED_COUNTER_NAME = name(OpenWebSocketCounter.class, "websocketClosed"); private static final String SESSION_DURATION_TIMER_NAME = name(OpenWebSocketCounter.class, "sessionDuration"); private static final String GAUGE_COUNT_GAUGE_NAME = name(OpenWebSocketCounter.class, "gaugeCount"); + private static final String OPEN_WEBSOCKET_BY_ASN_REGION_GAUGE_NAME = name(OpenWebSocketCounter.class, "openWebsocketsByAsnRegion"); public OpenWebSocketCounter(final String webSocketType, + final Supplier asnInfoProviderSupplier, final ClientReleaseManager clientReleaseManager) { + this.asnInfoProviderSupplier = asnInfoProviderSupplier; this.clientReleaseManager = clientReleaseManager; this.baseTags = Tags.of("webSocketType", webSocketType); @@ -48,6 +58,21 @@ public class OpenWebSocketCounter { public void countOpenWebSocket(final WebSocketSessionContext context) { final Timer.Sample sample = Timer.start(); + final Optional maybeOpenWebSocketsByAsnRegion; + + if (context.getClient().getRemoteAddress() instanceof InetSocketAddress inetSocketAddress) { + maybeOpenWebSocketsByAsnRegion = + asnInfoProviderSupplier.get().lookup(InetAddresses.toAddrString(inetSocketAddress.getAddress())) + .map(AsnInfo::regionCode) + .map(asnRegion -> openWebsocketsByAsnRegion.computeIfAbsent(asnRegion, region -> + Metrics.gauge(OPEN_WEBSOCKET_BY_ASN_REGION_GAUGE_NAME, Tags.of("asnRegion", region), + new AtomicInteger(0)))); + } else { + maybeOpenWebSocketsByAsnRegion = Optional.empty(); + } + + maybeOpenWebSocketsByAsnRegion.ifPresent(AtomicInteger::incrementAndGet); + @Nullable final UserAgent userAgent; { UserAgent parsedUserAgent; @@ -85,6 +110,7 @@ public class OpenWebSocketCounter { .register(Metrics.globalRegistry)); maybeOpenWebSocketCounter.ifPresent(AtomicInteger::decrementAndGet); + maybeOpenWebSocketsByAsnRegion.ifPresent(AtomicInteger::decrementAndGet); totalConnections.decrementAndGet(); Metrics.counter(WEB_SOCKET_CLOSED_COUNTER_NAME, tagsWithClientPlatform.and("status", String.valueOf(statusCode))) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 9f3bca729..da8206e82 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -7,8 +7,10 @@ package org.whispersystems.textsecuregcm.websocket; import com.google.common.annotations.VisibleForTesting; import java.util.Optional; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; @@ -55,12 +57,14 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final PushNotificationScheduler pushNotificationScheduler, final DisconnectionRequestManager disconnectionRequestManager, final Scheduler messageDeliveryScheduler, + final Supplier asnInfoProviderSupplier, final ClientReleaseManager clientReleaseManager, final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, final ExperimentEnrollmentManager experimentEnrollmentManager) { this(accountsManager, disconnectionRequestManager, + asnInfoProviderSupplier, clientReleaseManager, (account, device, client) -> new WebSocketConnection(receiptSender, messagesManager, @@ -80,6 +84,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { @VisibleForTesting AuthenticatedConnectListener( final AccountsManager accountsManager, final DisconnectionRequestManager disconnectionRequestManager, + final Supplier asnInfoProviderSupplier, final ClientReleaseManager clientReleaseManager, final WebSocketConnectionBuilder webSocketConnectionBuilder) { @@ -87,8 +92,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this.disconnectionRequestManager = disconnectionRequestManager; this.webSocketConnectionBuilder = webSocketConnectionBuilder; - this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", clientReleaseManager); - this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", clientReleaseManager); + this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", asnInfoProviderSupplier, clientReleaseManager); + this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", asnInfoProviderSupplier, clientReleaseManager); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java index a76675087..1758a6416 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java @@ -14,6 +14,8 @@ import java.util.Optional; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.controllers.ProvisioningController; import org.whispersystems.textsecuregcm.entities.MessageProtos; @@ -48,13 +50,14 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { private final Duration timeout; public ProvisioningConnectListener(final ProvisioningManager provisioningManager, + final Supplier asnInfoProviderSupplier, final ClientReleaseManager clientReleaseManager, final ScheduledExecutorService timeoutExecutor, final Duration timeout) { this.provisioningManager = provisioningManager; this.timeoutExecutor = timeoutExecutor; this.timeout = timeout; - this.openWebSocketCounter = new OpenWebSocketCounter("provisioning", clientReleaseManager); + this.openWebSocketCounter = new OpenWebSocketCounter("provisioning", asnInfoProviderSupplier, clientReleaseManager); } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java index de9ab693b..ef37d6b64 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/ProvisioningTimeoutIntegrationTest.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; @@ -109,7 +110,7 @@ public class ProvisioningTimeoutIntegrationTest { .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.setConnectListener( - new ProvisioningConnectListener(mock(ProvisioningManager.class), mock(ClientReleaseManager.class), scheduler, Duration.ofSeconds(5))); + new ProvisioningConnectListener(mock(ProvisioningManager.class), () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), scheduler, Duration.ofSeconds(5))); final WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java index f6d0ff1ef..de7135691 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListenerTest.java @@ -21,6 +21,7 @@ import java.util.Optional; import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -55,6 +56,7 @@ class AuthenticatedConnectListenerTest { authenticatedConnectListener = new AuthenticatedConnectListener(accountsManager, disconnectionRequestManager, + () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), (_, _, _) -> authenticatedWebSocketConnection); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java index 04021acf7..486f3363d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListenerTest.java @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.asn.AsnInfoProvider; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; @@ -40,7 +41,7 @@ class ProvisioningConnectListenerTest { provisioningManager = mock(ProvisioningManager.class); scheduledExecutorService = mock(ScheduledExecutorService.class); provisioningConnectListener = - new ProvisioningConnectListener(provisioningManager, mock(ClientReleaseManager.class), scheduledExecutorService, TIMEOUT); + new ProvisioningConnectListener(provisioningManager, () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), scheduledExecutorService, TIMEOUT); } @Test diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java index 4f6d0d6a4..1bcfd726d 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java @@ -5,6 +5,7 @@ package org.whispersystems.websocket; import com.google.common.net.HttpHeaders; +import java.net.SocketAddress; import java.nio.ByteBuffer; import java.security.SecureRandom; import java.time.Instant; @@ -105,4 +106,8 @@ public class WebSocketClient { private long generateRequestId() { return Math.abs(SECURE_RANDOM.nextLong()); } + + public SocketAddress getRemoteAddress() { + return session.getRemoteAddress(); + } }