diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index b2bbf07db..856b4328d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -316,7 +316,7 @@ public class WhisperServerConfiguration extends Configuration { @Valid @NotNull @JsonProperty - private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration(Duration.ofMillis(1)); + private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration(); @Valid @NotNull diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index c21867fc7..546f097b0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -258,6 +258,7 @@ import org.whispersystems.textsecuregcm.subscriptions.GooglePlayBillingManager; import org.whispersystems.textsecuregcm.subscriptions.StripeManager; import org.whispersystems.textsecuregcm.util.BufferingInterceptor; import org.whispersystems.textsecuregcm.util.ManagedAwsCrt; +import org.whispersystems.textsecuregcm.util.ManagedExecutors; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider; @@ -395,8 +396,10 @@ public class WhisperServerService extends Application webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), Duration.ofMillis(90000)); - webSocketEnvironment.jersey().register(new 
VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-")); + webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider( + "managed-async-websocket-virtual-thread", + config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor())); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter( config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/VirtualThreadConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/VirtualThreadConfiguration.java index 2e81aaf31..e4b2cca6e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/VirtualThreadConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/VirtualThreadConfiguration.java @@ -6,4 +6,38 @@ package org.whispersystems.textsecuregcm.configuration; import java.time.Duration; -public record VirtualThreadConfiguration(Duration pinEventThreshold) {} +/** + * Settings that govern how the server uses virtual threads. + * + * @param pinEventThreshold threshold for virtual-thread pin events, one millisecond when omitted + * (NOTE(review): the consumer of this value is outside this change; confirm the exact semantics there) + * @param maxConcurrentThreadsPerExecutor limit on concurrently running virtual threads per bounded executor, + * 1,000,000 when omitted; must be positive + */ +public record VirtualThreadConfiguration( + Duration pinEventThreshold, + Integer maxConcurrentThreadsPerExecutor) { + + /** Creates a configuration in which every setting takes its default value. */ + public VirtualThreadConfiguration() { + this(null, null); + } + + /** + * Substitutes defaults for omitted ({@code null}) settings and rejects limits that could never admit a task. + * + * @throws IllegalArgumentException if {@code maxConcurrentThreadsPerExecutor} is zero or negative + */ + public VirtualThreadConfiguration { + if (maxConcurrentThreadsPerExecutor == null) { + maxConcurrentThreadsPerExecutor = 1_000_000; + } else if (maxConcurrentThreadsPerExecutor <= 0) { + // A non-positive limit would make the bounded thread factory refuse every task + throw new IllegalArgumentException( + "maxConcurrentThreadsPerExecutor must be positive: " + maxConcurrentThreadsPerExecutor); + } + if (pinEventThreshold == null) { + pinEventThreshold = Duration.ofMillis(1); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/BoundedVirtualThreadFactory.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/BoundedVirtualThreadFactory.java new file mode 100644 index 000000000..8cb415f50 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/BoundedVirtualThreadFactory.java @@ -0,0 
+1,123 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.util; + +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; + +/** + * A thread factory that creates virtual threads while capping how many of them may be running at once. + * + * {@link #newThread(Runnable)} returns {@code null} once the cap is reached. A slot is claimed when a thread is handed + * out and is given back only after that thread's task has finished, so every returned thread must be started; a thread + * that is never started keeps its slot indefinitely. + */ +public class BoundedVirtualThreadFactory implements ThreadFactory { + + private static final Logger logger = LoggerFactory.getLogger(BoundedVirtualThreadFactory.class); + + private final AtomicInteger runningThreads = new AtomicInteger(); + private final ThreadFactory delegate; + private final int maxConcurrentThreads; + + private final Counter created; + private final Counter completed; + + /** + * Creates a factory and registers its "active" gauge and "created"/"completed" counters, tagged with the pool name. + * + * @param threadPoolName value of the {@code pool} metric tag and, followed by "-" and an index, the thread name + * @param maxConcurrentThreads the most threads from this factory that may be running at the same time + * @throws IllegalArgumentException if {@code maxConcurrentThreads} is zero or negative + */ + public BoundedVirtualThreadFactory(final String threadPoolName, final int maxConcurrentThreads) { + if (maxConcurrentThreads <= 0) { + throw new IllegalArgumentException("maxConcurrentThreads must be positive: " + maxConcurrentThreads); + } + this.maxConcurrentThreads = maxConcurrentThreads; + + final Tags tags = Tags.of("pool", threadPoolName); + Metrics.gauge( + MetricsUtil.name(BoundedVirtualThreadFactory.class, "active"), + tags, runningThreads, (rt) -> (double) rt.get()); + this.created = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "created"), tags); + this.completed = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "completed"), tags); + + // The virtual thread factory will initialize thread names by appending the thread index to the provided prefix + this.delegate = Thread.ofVirtual().name(threadPoolName + "-", 0).factory(); + + } + + /** + * Creates an unstarted virtual thread for the given task if the concurrency limit has not been reached. + * + * @param r the task to run + * @return the new thread, or {@code null} if the limit is currently exhausted + */ + @Override + public Thread newThread(final Runnable r) { + if (!tryAcquire()) { + return null; + } + Thread thread = null; + try { + // Give the slot back as soon as the task ends, whether it returns or throws + final Runnable wrapped = () -> { +
try { + r.run(); + } finally { + release(); + } + }; + thread = delegate.newThread(wrapped); + } finally { + if (thread == null) { + // No thread was produced (the delegate threw or returned null), so the slot will never be used + release(); + } + } + return thread; + } + + + /** Returns the current value of the running-thread counter. */ + @VisibleForTesting + int getRunningThreads() { + return runningThreads.get(); + } + + /** + * Claims a slot if one is free, retrying the compare-and-set whenever another thread changed the counter first. + * + * @return {@code true} if a slot was claimed, {@code false} if the limit has been reached + */ + private boolean tryAcquire() { + int old; + do { + old = runningThreads.get(); + if (old >= maxConcurrentThreads) { + return false; + } + } while (!runningThreads.compareAndSet(old, old + 1)); + created.increment(); + return true; + } + + /** Gives a slot back and counts the completion; logs an error if the counter has dropped below zero. */ + private void release() { + int updated = runningThreads.decrementAndGet(); + if (updated < 0) { + logger.error("Released a thread and count was {}, which should never happen", updated); + } + completed.increment(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ManagedExecutors.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ManagedExecutors.java new file mode 100644 index 000000000..7ec494a8d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/ManagedExecutors.java @@ -0,0 +1,40 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.util; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import io.dropwizard.core.setup.Environment; +import io.dropwizard.lifecycle.ExecutorServiceManager; +import io.dropwizard.util.Duration; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * Build Executor Services managed by dropwizard, supplementing executors provided by + * {@link io.dropwizard.lifecycle.setup.LifecycleEnvironment#executorService} + */ +public class ManagedExecutors { + + private static final Duration SHUTDOWN_DURATION = Duration.seconds(5); + + private ManagedExecutors() { + } + + public static ExecutorService 
newVirtualThreadPerTaskExecutor( + final String threadNamePrefix, + final int maxConcurrentThreads, + final Environment environment) { + + final BoundedVirtualThreadFactory threadFactory = + new BoundedVirtualThreadFactory(threadNamePrefix, maxConcurrentThreads); + final ExecutorService virtualThreadExecutor = Executors.newThreadPerTaskExecutor(threadFactory); + environment.lifecycle() + .manage(new ExecutorServiceManager(virtualThreadExecutor, SHUTDOWN_DURATION, threadNamePrefix)); + return virtualThreadExecutor; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProvider.java index db040c1ab..f8b866f7c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProvider.java @@ -15,6 +15,7 @@ import org.slf4j.LoggerFactory; @ManagedAsyncExecutor public class VirtualExecutorServiceProvider implements ExecutorServiceProvider { + private static final Logger logger = LoggerFactory.getLogger(VirtualExecutorServiceProvider.class); @@ -22,17 +23,27 @@ public class VirtualExecutorServiceProvider implements ExecutorServiceProvider { * Default thread pool executor termination timeout in milliseconds. 
*/ public static final int TERMINATION_TIMEOUT = 5000; - private final String virtualThreadNamePrefix; - public VirtualExecutorServiceProvider(final String virtualThreadNamePrefix) { + private final String virtualThreadNamePrefix; + private final int maxConcurrentThreads; + + /** + * @param virtualThreadNamePrefix name passed to the thread factory, which appends "-" and an index per thread + * @param maxConcurrentThreads maximum number of virtual threads the created executor may run at the same time + */ + public VirtualExecutorServiceProvider( + final String virtualThreadNamePrefix, + final int maxConcurrentThreads) { this.virtualThreadNamePrefix = virtualThreadNamePrefix; + this.maxConcurrentThreads = maxConcurrentThreads; } @Override public ExecutorService getExecutorService() { logger.info("Creating executor service with virtual thread per task"); - return Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name(virtualThreadNamePrefix, 0).factory()); + return Executors.newThreadPerTaskExecutor( + new BoundedVirtualThreadFactory(virtualThreadNamePrefix, maxConcurrentThreads)); } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java index a56009ffd..4183026b3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java @@ -34,7 +34,7 @@ public class BufferingInterceptorIntegrationTest { final TestController testController = new TestController(); environment.jersey().register(testController); environment.jersey().register(new BufferingInterceptor()); - environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-")); + environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread", 10)); JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java 
b/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java index b050d78db..910017759 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/WhisperServerServiceTest.java @@ -18,10 +18,17 @@ import jakarta.ws.rs.core.Response; import java.net.URI; import java.util.List; import java.util.Map; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; import org.eclipse.jetty.util.component.LifeCycle; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.StatusCode; import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -34,6 +41,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.HeaderUtils; +import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/BoundedVirtualThreadFactoryTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/BoundedVirtualThreadFactoryTest.java new file mode 100644 index 000000000..a8064aab4 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/BoundedVirtualThreadFactoryTest.java @@ -0,0 +1,153 @@ +/* 
+ * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.fail; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class BoundedVirtualThreadFactoryTest { + + private static final int MAX_THREADS = 10; + + private BoundedVirtualThreadFactory factory; + + @BeforeEach + void setUp() { + factory = new BoundedVirtualThreadFactory("test", MAX_THREADS); + } + + @Test + void releaseWhenTaskThrows() throws InterruptedException { + final UncaughtExceptionHolder uncaughtExceptionHolder = new UncaughtExceptionHolder(); + final Thread t = submit(() -> { + throw new IllegalArgumentException("test"); + }, uncaughtExceptionHolder); + assertThat(t).isNotNull(); + assertThat(t.join(Duration.ofSeconds(1))).isTrue(); + assertThat(uncaughtExceptionHolder.exception).isNotNull().isInstanceOf(IllegalArgumentException.class); + + submitUntilRejected(); + } + + @Test + void releaseWhenRejected() throws InterruptedException { + submitUntilRejected(); + submitUntilRejected(); + } + + @Test + void executorServiceRejectsAtLimit() throws InterruptedException { + try (final ExecutorService executor = Executors.newThreadPerTaskExecutor(factory)) { + + final CountDownLatch latch = new CountDownLatch(1); + for (int i = 0; i < MAX_THREADS; i++) { + executor.submit(() -> { + try { + latch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + } + 
assertThatExceptionOfType(RejectedExecutionException.class).isThrownBy(() -> executor.submit(Util.NOOP)); + latch.countDown(); + + executor.shutdown(); + assertThat(executor.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + } + } + + @Test + void stressTest() throws InterruptedException { + for (int iteration = 0; iteration < 50; iteration++) { + final CountDownLatch latch = new CountDownLatch(MAX_THREADS); + + // submit a task that submits a task maxThreads/2 times + final Thread[] threads = new Thread[MAX_THREADS]; + for (int i = 0; i < MAX_THREADS; i+=2) { + int outerThreadIndex = i; + int innerThreadIndex = i + 1; + + threads[outerThreadIndex] = submit(() -> { + threads[innerThreadIndex] = submit(latch::countDown); + latch.countDown(); + }); + } + latch.await(); + + // Every outer task stores its inner thread before counting down, so all slots are populated at this point + for (Thread thread : threads) { + assertThat(thread).isNotNull(); + thread.join(); + } + + assertThat(factory.getRunningThreads()).isEqualTo(0); + } + + submitUntilRejected(); + } + + /** + * Verify we can submit up to the concurrency limit (and no more) + */ + private void submitUntilRejected() throws InterruptedException { + final CountDownLatch finish = new CountDownLatch(1); + final List<Thread> threads = IntStream.range(0, MAX_THREADS).mapToObj(_ -> submit(() -> { + try { + finish.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + })).toList(); + + assertThat(submit(Util.NOOP)).isNull(); + + finish.countDown(); + + for (Thread thread : threads) { + thread.join(); + } + assertThat(factory.getRunningThreads()).isEqualTo(0); + } + + private Thread submit(final Runnable runnable) { + return submit(runnable, (t, e) -> + fail("Uncaught exception on thread %s : %s", t, e)); + } + + private Thread submit(final Runnable runnable, final Thread.UncaughtExceptionHandler handler) { + final Thread thread = factory.newThread(runnable); + if (thread == null) { + return null; + } + if (handler != null) { + 
thread.setUncaughtExceptionHandler(handler); + } + thread.start(); + return thread; + } + + private static class UncaughtExceptionHolder implements Thread.UncaughtExceptionHandler { + Throwable exception = null; + + @Override + public void uncaughtException(final Thread t, final Throwable e) { + exception = e; + } + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProviderTest.java index 2f848d4ab..bc3005e13 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProviderTest.java @@ -13,20 +13,37 @@ import jakarta.ws.rs.GET; import jakarta.ws.rs.Path; import jakarta.ws.rs.core.Response; import java.security.Principal; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @ExtendWith(DropwizardExtensionsSupport.class) class VirtualExecutorServiceProviderTest { - private static final ResourceExtension resources = ResourceExtension.builder() + private final TestController testController = new TestController(); + private final ResourceExtension resources = ResourceExtension.builder() .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addProvider(new 
VirtualExecutorServiceProvider("virtual-thread-")) - .addResource(new TestController()) + .addProvider(new VirtualExecutorServiceProvider("virtual-thread", 2)) + .addResource(testController) .build(); + @AfterEach + void tearDown() { + testController.release(); + } + @Test public void testManagedAsyncThread() { final Response response = resources.getJerseyTest() @@ -37,6 +54,22 @@ class VirtualExecutorServiceProviderTest { assertThat(threadName).startsWith("virtual-thread-"); } + @Test + public void testConcurrencyLimit() throws InterruptedException, TimeoutException { + final BlockingQueue<Response> responses = new LinkedBlockingQueue<>(); + final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor(); + for (int i = 0; i < 3; i++) { + executor.submit(() -> responses.offer(resources.getJerseyTest().target("/v1/test/await").request().get())); + } + final Response rejectedResponse = responses.poll(10, TimeUnit.SECONDS); + assertThat(rejectedResponse).isNotNull().extracting(Response::getStatus).isEqualTo(500); + + assertThat(responses.isEmpty()).isTrue(); + assertThat(testController.release()).isEqualTo(2); + assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200); + assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200); + } + @Test public void testUnmanagedThread() { final Response response = resources.getJerseyTest() @@ -49,6 +82,7 @@ class VirtualExecutorServiceProviderTest { @Path("/v1/test") public static class TestController { + private final List<CountDownLatch> latches = new ArrayList<>(); @GET @Path("/managed-async") @@ -57,12 +91,34 @@ class VirtualExecutorServiceProviderTest { return Response.ok().entity(Thread.currentThread().getName()).build(); } + @GET + @Path("/await") + @ManagedAsync + public Response await() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + synchronized (this) { + latches.add(latch); + } + latch.await(); + return 
Response.ok().build(); + } + @GET @Path("/unmanaged") public Response unmanaged() { return Response.ok().entity(Thread.currentThread().getName()).build(); } + synchronized int release() { + final Iterator<CountDownLatch> iterator = latches.iterator(); + int count; + for (count = 0; iterator.hasNext(); count++) { + iterator.next().countDown(); + iterator.remove(); + } + return count; + } + } public static class TestPrincipal implements Principal {