Add a bounded virtual executor service

This commit is contained in:
ravi-signal
2025-08-15 15:49:50 -05:00
committed by GitHub
parent c883cd8148
commit b76eaa1098
10 changed files with 412 additions and 21 deletions

View File

@@ -34,7 +34,7 @@ public class BufferingInterceptorIntegrationTest {
final TestController testController = new TestController();
environment.jersey().register(testController);
environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-"));
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10));
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
}
}

View File

@@ -18,10 +18,17 @@ import jakarta.ws.rs.core.Response;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
@@ -34,6 +41,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;

View File

@@ -0,0 +1,153 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.fail;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
class BoundedVirtualThreadFactoryTest {
private final static int MAX_THREADS = 10;
private BoundedVirtualThreadFactory factory;
@BeforeEach
void setUp() {
factory = new BoundedVirtualThreadFactory("test", MAX_THREADS);
}
@Test
void releaseWhenTaskThrows() throws InterruptedException {
final UncaughtExceptionHolder uncaughtExceptionHolder = new UncaughtExceptionHolder();
final Thread t = submit(() -> {
throw new IllegalArgumentException("test");
}, uncaughtExceptionHolder);
assertThat(t).isNotNull();
t.join(Duration.ofSeconds(1));
assertThat(uncaughtExceptionHolder.exception).isNotNull().isInstanceOf(IllegalArgumentException.class);
submitUntilRejected();
}
@Test
void releaseWhenRejected() throws InterruptedException {
submitUntilRejected();
submitUntilRejected();
}
@Test
void executorServiceRejectsAtLimit() throws InterruptedException {
try (final ExecutorService executor = Executors.newThreadPerTaskExecutor(factory)) {
final CountDownLatch latch = new CountDownLatch(1);
for (int i = 0; i < MAX_THREADS; i++) {
executor.submit(() -> {
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
}
assertThatExceptionOfType(RejectedExecutionException.class).isThrownBy(() -> executor.submit(Util.NOOP));
latch.countDown();
executor.shutdown();
assertThat(executor.awaitTermination(5, TimeUnit.SECONDS)).isTrue();
}
}
@Test
void stressTest() throws InterruptedException {
for (int iteration = 0; iteration < 50; iteration++) {
final CountDownLatch latch = new CountDownLatch(MAX_THREADS);
// submit a task that submits a task maxThreads/2 times
final Thread[] threads = new Thread[MAX_THREADS];
for (int i = 0; i < MAX_THREADS; i+=2) {
int outerThreadIndex = i;
int innerThreadIndex = i + 1;
threads[outerThreadIndex] = submit(() -> {
latch.countDown();
threads[innerThreadIndex] = submit(latch::countDown);
});
}
latch.await();
// All threads must be created at this point, wait for them all to complete
for (Thread thread : threads) {
assertThat(thread).isNotNull();
thread.join();
}
assertThat(factory.getRunningThreads()).isEqualTo(0);
}
submitUntilRejected();
}
/**
* Verify we can submit up to the concurrency limit (and no more)
*/
private void submitUntilRejected() throws InterruptedException {
final CountDownLatch finish = new CountDownLatch(1);
final List<Thread> threads = IntStream.range(0, MAX_THREADS).mapToObj(_ -> submit(() -> {
try {
finish.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
})).toList();
assertThat(submit(Util.NOOP)).isNull();
finish.countDown();
for (Thread thread : threads) {
thread.join();
}
assertThat(factory.getRunningThreads()).isEqualTo(0);
}
private Thread submit(final Runnable runnable) {
return submit(runnable, (t, e) ->
fail("Uncaught exception on thread {} : {}", t, e));
}
private Thread submit(final Runnable runnable, final Thread.UncaughtExceptionHandler handler) {
final Thread thread = factory.newThread(runnable);
if (thread == null) {
return null;
}
if (handler != null) {
thread.setUncaughtExceptionHandler(handler);
}
thread.start();
return thread;
}
private static class UncaughtExceptionHolder implements Thread.UncaughtExceptionHandler {
Throwable exception = null;
@Override
public void uncaughtException(final Thread t, final Throwable e) {
exception = e;
}
}
}

View File

@@ -13,20 +13,37 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@ExtendWith(DropwizardExtensionsSupport.class)
class VirtualExecutorServiceProviderTest {
private static final ResourceExtension resources = ResourceExtension.builder()
private final TestController testController = new TestController();
private final ResourceExtension resources = ResourceExtension.builder()
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new VirtualExecutorServiceProvider("virtual-thread-"))
.addResource(new TestController())
.addProvider(new VirtualExecutorServiceProvider( "virtual-thread-", 2))
.addResource(testController)
.build();
@AfterEach
void setUp() {
testController.release();
}
@Test
public void testManagedAsyncThread() {
final Response response = resources.getJerseyTest()
@@ -37,6 +54,22 @@ class VirtualExecutorServiceProviderTest {
assertThat(threadName).startsWith("virtual-thread-");
}
@Test
public void testConcurrencyLimit() throws InterruptedException, TimeoutException {
final BlockingQueue<Response> responses = new LinkedBlockingQueue<>();
final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
for (int i = 0; i < 3; i++) {
executor.submit(() -> responses.offer(resources.getJerseyTest().target("/v1/test/await").request().get()));
}
final Response rejectedResponse = responses.poll(10, TimeUnit.SECONDS);
assertThat(rejectedResponse).isNotNull().extracting(Response::getStatus).isEqualTo(500);
assertThat(responses.isEmpty()).isTrue();
assertThat(testController.release()).isEqualTo(2);
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
}
@Test
public void testUnmanagedThread() {
final Response response = resources.getJerseyTest()
@@ -49,6 +82,7 @@ class VirtualExecutorServiceProviderTest {
@Path("/v1/test")
public static class TestController {
private List<CountDownLatch> latches = new ArrayList<>();
@GET
@Path("/managed-async")
@@ -57,12 +91,34 @@ class VirtualExecutorServiceProviderTest {
return Response.ok().entity(Thread.currentThread().getName()).build();
}
@GET
@Path("/await")
@ManagedAsync
public Response await() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
synchronized (this) {
latches.add(latch);
}
latch.await();
return Response.ok().build();
}
@GET
@Path("/unmanaged")
public Response unmanaged() {
return Response.ok().entity(Thread.currentThread().getName()).build();
}
synchronized int release() {
final Iterator<CountDownLatch> iterator = latches.iterator();
int count;
for (count = 0; iterator.hasNext(); count++) {
iterator.next().countDown();
iterator.remove();
}
return count;
}
}
public static class TestPrincipal implements Principal {