Add a bounded virtual executor service

This commit is contained in:
ravi-signal
2025-08-15 15:49:50 -05:00
committed by GitHub
parent c883cd8148
commit b76eaa1098
10 changed files with 412 additions and 21 deletions

View File

@@ -316,7 +316,7 @@ public class WhisperServerConfiguration extends Configuration {
@Valid
@NotNull
@JsonProperty
private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration(Duration.ofMillis(1));
private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration();
@Valid
@NotNull

View File

@@ -258,6 +258,7 @@ import org.whispersystems.textsecuregcm.subscriptions.GooglePlayBillingManager;
import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import org.whispersystems.textsecuregcm.util.ManagedExecutors;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
@@ -395,8 +396,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(new ManagedAwsCrt());
final ExecutorService awsSdkMetricsExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "awsSdkMetrics-%d"));
final ExecutorService awsSdkMetricsExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
"awsSdkMetrics",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
final DynamoDbAsyncClient dynamoDbAsyncClient = config.getDynamoDbClientConfiguration()
.buildAsyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbAsync"));
@@ -561,14 +564,23 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.maxThreads(2)
.minThreads(2)
.build();
ExecutorService googlePlayBillingExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "googlePlayBilling-%d"));
ExecutorService appleAppStoreExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "appleAppStore-%d"));
ExecutorService clientEventExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "clientEvent-%d"));
ExecutorService disconnectionRequestListenerExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "disconnectionRequest-%d"));
ExecutorService googlePlayBillingExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
"googlePlayBilling",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
ExecutorService appleAppStoreExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
"appleAppStore",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
ExecutorService clientEventExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
"clientEvent",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
ExecutorService disconnectionRequestListenerExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
"disconnectionRequest",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
ScheduledExecutorService appleAppStoreRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "appleAppStoreRetry").threads(1).build();
ScheduledExecutorService subscriptionProcessorRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "subscriptionProcessorRetry").threads(1).build();
@@ -976,7 +988,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager));
environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-"));
environment.jersey().register(new VirtualExecutorServiceProvider(
"managed-async-virtual-thread",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor()));
environment.jersey().register(new RateLimitByIpFilter(rateLimiters));
environment.jersey().register(new RequestStatisticsFilter(TrafficSource.HTTP));
environment.jersey().register(MultiRecipientMessageProvider.class);
@@ -987,7 +1001,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
///
WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider(
"managed-async-websocket-virtual-thread",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter(
config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));

View File

@@ -6,4 +6,20 @@ package org.whispersystems.textsecuregcm.configuration;
import java.time.Duration;
public record VirtualThreadConfiguration(Duration pinEventThreshold) {}
public record VirtualThreadConfiguration(
Duration pinEventThreshold,
Integer maxConcurrentThreadsPerExecutor) {
public VirtualThreadConfiguration() {
this(null, null);
}
public VirtualThreadConfiguration {
if (maxConcurrentThreadsPerExecutor == null) {
maxConcurrentThreadsPerExecutor = 1_000_000;
}
if (pinEventThreshold == null) {
pinEventThreshold = Duration.ofMillis(1);
}
}
}

View File

@@ -0,0 +1,94 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/**
* A thread factory that creates virtual threads but limits the total number of virtual threads created.
*/
public class BoundedVirtualThreadFactory implements ThreadFactory {
private static final Logger logger = LoggerFactory.getLogger(BoundedVirtualThreadFactory.class);
private final AtomicInteger runningThreads = new AtomicInteger();
private final ThreadFactory delegate;
private final int maxConcurrentThreads;
private final Counter created;
private final Counter completed;
public BoundedVirtualThreadFactory(final String threadPoolName, final int maxConcurrentThreads) {
this.maxConcurrentThreads = maxConcurrentThreads;
final Tags tags = Tags.of("pool", threadPoolName);
Metrics.gauge(
MetricsUtil.name(BoundedVirtualThreadFactory.class, "active"),
tags, runningThreads, (rt) -> (double) rt.get());
this.created = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "created"), tags);
this.completed = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "completed"), tags);
// The virtual thread factory will initialize thread names by appending the thread index to the provided prefix
this.delegate = Thread.ofVirtual().name(threadPoolName + "-", 0).factory();
}
@Override
public Thread newThread(final Runnable r) {
if (!tryAcquire()) {
return null;
}
Thread thread = null;
try {
final Runnable wrapped = () -> {
try {
r.run();
} finally {
release();
}
};
thread = delegate.newThread(wrapped);
} finally {
if (thread == null) {
release();
}
}
return thread;
}
@VisibleForTesting
int getRunningThreads() {
return runningThreads.get();
}
private boolean tryAcquire() {
int old;
do {
old = runningThreads.get();
if (old >= maxConcurrentThreads) {
return false;
}
} while (!runningThreads.compareAndSet(old, old + 1));
created.increment();
return true;
}
private void release() {
int updated = runningThreads.decrementAndGet();
if (updated < 0) {
logger.error("Released a thread and count was {}, which should never happen", updated);
}
completed.increment();
}
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.lifecycle.ExecutorServiceManager;
import io.dropwizard.util.Duration;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* Build Executor Services managed by dropwizard, supplementing executors provided by
* {@link io.dropwizard.lifecycle.setup.LifecycleEnvironment#executorService}
*/
public class ManagedExecutors {
private static final Duration SHUTDOWN_DURATION = Duration.seconds(5);
private ManagedExecutors() {
}
public static ExecutorService newVirtualThreadPerTaskExecutor(
final String threadNamePrefix,
final int maxConcurrentThreads,
final Environment environment) {
final BoundedVirtualThreadFactory threadFactory =
new BoundedVirtualThreadFactory(threadNamePrefix, maxConcurrentThreads);
final ExecutorService virtualThreadExecutor = Executors.newThreadPerTaskExecutor(threadFactory);
environment.lifecycle()
.manage(new ExecutorServiceManager(virtualThreadExecutor, SHUTDOWN_DURATION, threadNamePrefix));
return virtualThreadExecutor;
}
}

View File

@@ -15,6 +15,7 @@ import org.slf4j.LoggerFactory;
@ManagedAsyncExecutor
public class VirtualExecutorServiceProvider implements ExecutorServiceProvider {
private static final Logger logger = LoggerFactory.getLogger(VirtualExecutorServiceProvider.class);
@@ -22,17 +23,24 @@ public class VirtualExecutorServiceProvider implements ExecutorServiceProvider {
* Default thread pool executor termination timeout in milliseconds.
*/
public static final int TERMINATION_TIMEOUT = 5000;
private final String virtualThreadNamePrefix;
public VirtualExecutorServiceProvider(final String virtualThreadNamePrefix) {
private final String virtualThreadNamePrefix;
private final int maxConcurrentThreads;
public VirtualExecutorServiceProvider(
final String virtualThreadNamePrefix,
final int maxConcurrentThreads) {
this.virtualThreadNamePrefix = virtualThreadNamePrefix;
this.maxConcurrentThreads = maxConcurrentThreads;
}
@Override
public ExecutorService getExecutorService() {
logger.info("Creating executor service with virtual thread per task");
return Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name(virtualThreadNamePrefix, 0).factory());
final ExecutorService executor = Executors.newThreadPerTaskExecutor(
new BoundedVirtualThreadFactory(virtualThreadNamePrefix, maxConcurrentThreads));
return executor;
}
@Override