Keep a count of open WebSockets by ASN region

This commit is contained in:
Jon Chambers
2026-03-09 14:55:28 -04:00
committed by Jon Chambers
parent e96149ecf5
commit 11df65b8d8
8 changed files with 50 additions and 7 deletions

View File

@@ -1072,7 +1072,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, disconnectionRequestManager,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager
messageDeliveryScheduler, asnInfoProviderSupplier, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager
));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
@@ -1150,7 +1150,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedDevice> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, clientReleaseManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90)));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, asnInfoProviderSupplier, clientReleaseManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90)));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
provisioningEnvironment.jersey().register(new KeepAliveController(redisMessageAvailabilityManager));
provisioningEnvironment.jersey().register(new TimestampResponseFilter());

View File

@@ -2,14 +2,19 @@ package org.whispersystems.textsecuregcm.metrics;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.net.InetAddresses;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.asn.AsnInfo;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
@@ -18,11 +23,13 @@ import org.whispersystems.websocket.session.WebSocketSessionContext;
public class OpenWebSocketCounter {
private final Supplier<AsnInfoProvider> asnInfoProviderSupplier;
private final ClientReleaseManager clientReleaseManager;
private final Tags baseTags;
private final Map<Tags, AtomicInteger> openWebsocketsByTags;
private final Map<String, AtomicInteger> openWebsocketsByAsnRegion = new ConcurrentHashMap<>();
private final AtomicInteger totalConnections;
private static final int MAX_COUNTERS = 4096;
@@ -33,10 +40,13 @@ public class OpenWebSocketCounter {
private static final String WEB_SOCKET_CLOSED_COUNTER_NAME = name(OpenWebSocketCounter.class, "websocketClosed");
private static final String SESSION_DURATION_TIMER_NAME = name(OpenWebSocketCounter.class, "sessionDuration");
private static final String GAUGE_COUNT_GAUGE_NAME = name(OpenWebSocketCounter.class, "gaugeCount");
private static final String OPEN_WEBSOCKET_BY_ASN_REGION_GAUGE_NAME = name(OpenWebSocketCounter.class, "openWebsocketsByAsnRegion");
public OpenWebSocketCounter(final String webSocketType,
final Supplier<AsnInfoProvider> asnInfoProviderSupplier,
final ClientReleaseManager clientReleaseManager) {
this.asnInfoProviderSupplier = asnInfoProviderSupplier;
this.clientReleaseManager = clientReleaseManager;
this.baseTags = Tags.of("webSocketType", webSocketType);
@@ -48,6 +58,21 @@ public class OpenWebSocketCounter {
public void countOpenWebSocket(final WebSocketSessionContext context) {
final Timer.Sample sample = Timer.start();
final Optional<AtomicInteger> maybeOpenWebSocketsByAsnRegion;
if (context.getClient().getRemoteAddress() instanceof InetSocketAddress inetSocketAddress) {
maybeOpenWebSocketsByAsnRegion =
asnInfoProviderSupplier.get().lookup(InetAddresses.toAddrString(inetSocketAddress.getAddress()))
.map(AsnInfo::regionCode)
.map(asnRegion -> openWebsocketsByAsnRegion.computeIfAbsent(asnRegion, region ->
Metrics.gauge(OPEN_WEBSOCKET_BY_ASN_REGION_GAUGE_NAME, Tags.of("asnRegion", region),
new AtomicInteger(0))));
} else {
maybeOpenWebSocketsByAsnRegion = Optional.empty();
}
maybeOpenWebSocketsByAsnRegion.ifPresent(AtomicInteger::incrementAndGet);
@Nullable final UserAgent userAgent;
{
UserAgent parsedUserAgent;
@@ -85,6 +110,7 @@ public class OpenWebSocketCounter {
.register(Metrics.globalRegistry));
maybeOpenWebSocketCounter.ifPresent(AtomicInteger::decrementAndGet);
maybeOpenWebSocketsByAsnRegion.ifPresent(AtomicInteger::decrementAndGet);
totalConnections.decrementAndGet();
Metrics.counter(WEB_SOCKET_CLOSED_COUNTER_NAME, tagsWithClientPlatform.and("status", String.valueOf(statusCode)))

View File

@@ -7,8 +7,10 @@ package org.whispersystems.textsecuregcm.websocket;
import com.google.common.annotations.VisibleForTesting;
import java.util.Optional;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
@@ -55,12 +57,14 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final PushNotificationScheduler pushNotificationScheduler,
final DisconnectionRequestManager disconnectionRequestManager,
final Scheduler messageDeliveryScheduler,
final Supplier<AsnInfoProvider> asnInfoProviderSupplier,
final ClientReleaseManager clientReleaseManager,
final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this(accountsManager,
disconnectionRequestManager,
asnInfoProviderSupplier,
clientReleaseManager,
(account, device, client) -> new WebSocketConnection(receiptSender,
messagesManager,
@@ -80,6 +84,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@VisibleForTesting AuthenticatedConnectListener(
final AccountsManager accountsManager,
final DisconnectionRequestManager disconnectionRequestManager,
final Supplier<AsnInfoProvider> asnInfoProviderSupplier,
final ClientReleaseManager clientReleaseManager,
final WebSocketConnectionBuilder webSocketConnectionBuilder) {
@@ -87,8 +92,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.disconnectionRequestManager = disconnectionRequestManager;
this.webSocketConnectionBuilder = webSocketConnectionBuilder;
this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", clientReleaseManager);
this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", clientReleaseManager);
this.openAuthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-authenticated", asnInfoProviderSupplier, clientReleaseManager);
this.openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter("rpc-unauthenticated", asnInfoProviderSupplier, clientReleaseManager);
}
@Override

View File

@@ -14,6 +14,8 @@ import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.ProvisioningController;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
@@ -48,13 +50,14 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
private final Duration timeout;
public ProvisioningConnectListener(final ProvisioningManager provisioningManager,
final Supplier<AsnInfoProvider> asnInfoProviderSupplier,
final ClientReleaseManager clientReleaseManager,
final ScheduledExecutorService timeoutExecutor,
final Duration timeout) {
this.provisioningManager = provisioningManager;
this.timeoutExecutor = timeoutExecutor;
this.timeout = timeout;
this.openWebSocketCounter = new OpenWebSocketCounter("provisioning", clientReleaseManager);
this.openWebSocketCounter = new OpenWebSocketCounter("provisioning", asnInfoProviderSupplier, clientReleaseManager);
}
@Override

View File

@@ -37,6 +37,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
@@ -109,7 +110,7 @@ public class ProvisioningTimeoutIntegrationTest {
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.setConnectListener(
new ProvisioningConnectListener(mock(ProvisioningManager.class), mock(ClientReleaseManager.class), scheduler, Duration.ofSeconds(5)));
new ProvisioningConnectListener(mock(ProvisioningManager.class), () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), scheduler, Duration.ofSeconds(5)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,

View File

@@ -21,6 +21,7 @@ import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
@@ -55,6 +56,7 @@ class AuthenticatedConnectListenerTest {
authenticatedConnectListener = new AuthenticatedConnectListener(accountsManager,
disconnectionRequestManager,
() -> mock(AsnInfoProvider.class),
mock(ClientReleaseManager.class),
(_, _, _) -> authenticatedWebSocketConnection);

View File

@@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.asn.AsnInfoProvider;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@@ -40,7 +41,7 @@ class ProvisioningConnectListenerTest {
provisioningManager = mock(ProvisioningManager.class);
scheduledExecutorService = mock(ScheduledExecutorService.class);
provisioningConnectListener =
new ProvisioningConnectListener(provisioningManager, mock(ClientReleaseManager.class), scheduledExecutorService, TIMEOUT);
new ProvisioningConnectListener(provisioningManager, () -> mock(AsnInfoProvider.class), mock(ClientReleaseManager.class), scheduledExecutorService, TIMEOUT);
}
@Test

View File

@@ -5,6 +5,7 @@
package org.whispersystems.websocket;
import com.google.common.net.HttpHeaders;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.time.Instant;
@@ -105,4 +106,8 @@ public class WebSocketClient {
private long generateRequestId() {
return Math.abs(SECURE_RANDOM.nextLong());
}
public SocketAddress getRemoteAddress() {
return session.getRemoteAddress();
}
}