Add a bounded virtual executor service

This commit is contained in:
ravi-signal
2025-08-15 15:49:50 -05:00
committed by GitHub
parent c883cd8148
commit b76eaa1098
10 changed files with 412 additions and 21 deletions

View File

@@ -316,7 +316,7 @@ public class WhisperServerConfiguration extends Configuration {
@Valid @Valid
@NotNull @NotNull
@JsonProperty @JsonProperty
private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration(Duration.ofMillis(1)); private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration();
@Valid @Valid
@NotNull @NotNull

View File

@@ -258,6 +258,7 @@ import org.whispersystems.textsecuregcm.subscriptions.GooglePlayBillingManager;
import org.whispersystems.textsecuregcm.subscriptions.StripeManager; import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.util.BufferingInterceptor; import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt; import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import org.whispersystems.textsecuregcm.util.ManagedExecutors;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider; import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
@@ -395,8 +396,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(new ManagedAwsCrt()); environment.lifecycle().manage(new ManagedAwsCrt());
final ExecutorService awsSdkMetricsExecutor = environment.lifecycle() final ExecutorService awsSdkMetricsExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
.virtualExecutorService(name(getClass(), "awsSdkMetrics-%d")); "awsSdkMetrics",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
final DynamoDbAsyncClient dynamoDbAsyncClient = config.getDynamoDbClientConfiguration() final DynamoDbAsyncClient dynamoDbAsyncClient = config.getDynamoDbClientConfiguration()
.buildAsyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbAsync")); .buildAsyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbAsync"));
@@ -561,14 +564,23 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.maxThreads(2) .maxThreads(2)
.minThreads(2) .minThreads(2)
.build(); .build();
ExecutorService googlePlayBillingExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "googlePlayBilling-%d")); ExecutorService googlePlayBillingExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
ExecutorService appleAppStoreExecutor = environment.lifecycle() "googlePlayBilling",
.virtualExecutorService(name(getClass(), "appleAppStore-%d")); config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
ExecutorService clientEventExecutor = environment.lifecycle() environment);
.virtualExecutorService(name(getClass(), "clientEvent-%d")); ExecutorService appleAppStoreExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
ExecutorService disconnectionRequestListenerExecutor = environment.lifecycle() "appleAppStore",
.virtualExecutorService(name(getClass(), "disconnectionRequest-%d")); config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
ExecutorService clientEventExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
"clientEvent",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
ExecutorService disconnectionRequestListenerExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
"disconnectionRequest",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
environment);
ScheduledExecutorService appleAppStoreRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "appleAppStoreRetry").threads(1).build(); ScheduledExecutorService appleAppStoreRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "appleAppStoreRetry").threads(1).build();
ScheduledExecutorService subscriptionProcessorRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "subscriptionProcessorRetry").threads(1).build(); ScheduledExecutorService subscriptionProcessorRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "subscriptionProcessorRetry").threads(1).build();
@@ -976,7 +988,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new BufferingInterceptor()); environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager)); environment.jersey().register(new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager));
environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-")); environment.jersey().register(new VirtualExecutorServiceProvider(
"managed-async-virtual-thread",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor()));
environment.jersey().register(new RateLimitByIpFilter(rateLimiters)); environment.jersey().register(new RateLimitByIpFilter(rateLimiters));
environment.jersey().register(new RequestStatisticsFilter(TrafficSource.HTTP)); environment.jersey().register(new RequestStatisticsFilter(TrafficSource.HTTP));
environment.jersey().register(MultiRecipientMessageProvider.class); environment.jersey().register(MultiRecipientMessageProvider.class);
@@ -987,7 +1001,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
/// ///
WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = new WebSocketEnvironment<>(environment, WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), Duration.ofMillis(90000)); config.getWebSocketConfiguration(), Duration.ofMillis(90000));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-")); webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider(
"managed-async-websocket-virtual-thread",
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter( webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter(
config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC())); config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));

View File

@@ -6,4 +6,20 @@ package org.whispersystems.textsecuregcm.configuration;
import java.time.Duration; import java.time.Duration;
public record VirtualThreadConfiguration(Duration pinEventThreshold) {} public record VirtualThreadConfiguration(
Duration pinEventThreshold,
Integer maxConcurrentThreadsPerExecutor) {
public VirtualThreadConfiguration() {
this(null, null);
}
public VirtualThreadConfiguration {
if (maxConcurrentThreadsPerExecutor == null) {
maxConcurrentThreadsPerExecutor = 1_000_000;
}
if (pinEventThreshold == null) {
pinEventThreshold = Duration.ofMillis(1);
}
}
}

View File

@@ -0,0 +1,94 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/**
* A thread factory that creates virtual threads but limits the total number of virtual threads created.
*/
public class BoundedVirtualThreadFactory implements ThreadFactory {
private static final Logger logger = LoggerFactory.getLogger(BoundedVirtualThreadFactory.class);
private final AtomicInteger runningThreads = new AtomicInteger();
private final ThreadFactory delegate;
private final int maxConcurrentThreads;
private final Counter created;
private final Counter completed;
public BoundedVirtualThreadFactory(final String threadPoolName, final int maxConcurrentThreads) {
this.maxConcurrentThreads = maxConcurrentThreads;
final Tags tags = Tags.of("pool", threadPoolName);
Metrics.gauge(
MetricsUtil.name(BoundedVirtualThreadFactory.class, "active"),
tags, runningThreads, (rt) -> (double) rt.get());
this.created = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "created"), tags);
this.completed = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "completed"), tags);
// The virtual thread factory will initialize thread names by appending the thread index to the provided prefix
this.delegate = Thread.ofVirtual().name(threadPoolName + "-", 0).factory();
}
@Override
public Thread newThread(final Runnable r) {
if (!tryAcquire()) {
return null;
}
Thread thread = null;
try {
final Runnable wrapped = () -> {
try {
r.run();
} finally {
release();
}
};
thread = delegate.newThread(wrapped);
} finally {
if (thread == null) {
release();
}
}
return thread;
}
@VisibleForTesting
int getRunningThreads() {
return runningThreads.get();
}
private boolean tryAcquire() {
int old;
do {
old = runningThreads.get();
if (old >= maxConcurrentThreads) {
return false;
}
} while (!runningThreads.compareAndSet(old, old + 1));
created.increment();
return true;
}
private void release() {
int updated = runningThreads.decrementAndGet();
if (updated < 0) {
logger.error("Released a thread and count was {}, which should never happen", updated);
}
completed.increment();
}
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.lifecycle.ExecutorServiceManager;
import io.dropwizard.util.Duration;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* Build Executor Services managed by dropwizard, supplementing executors provided by
* {@link io.dropwizard.lifecycle.setup.LifecycleEnvironment#executorService}
*/
public class ManagedExecutors {
private static final Duration SHUTDOWN_DURATION = Duration.seconds(5);
private ManagedExecutors() {
}
public static ExecutorService newVirtualThreadPerTaskExecutor(
final String threadNamePrefix,
final int maxConcurrentThreads,
final Environment environment) {
final BoundedVirtualThreadFactory threadFactory =
new BoundedVirtualThreadFactory(threadNamePrefix, maxConcurrentThreads);
final ExecutorService virtualThreadExecutor = Executors.newThreadPerTaskExecutor(threadFactory);
environment.lifecycle()
.manage(new ExecutorServiceManager(virtualThreadExecutor, SHUTDOWN_DURATION, threadNamePrefix));
return virtualThreadExecutor;
}
}

View File

@@ -15,6 +15,7 @@ import org.slf4j.LoggerFactory;
@ManagedAsyncExecutor @ManagedAsyncExecutor
public class VirtualExecutorServiceProvider implements ExecutorServiceProvider { public class VirtualExecutorServiceProvider implements ExecutorServiceProvider {
private static final Logger logger = LoggerFactory.getLogger(VirtualExecutorServiceProvider.class); private static final Logger logger = LoggerFactory.getLogger(VirtualExecutorServiceProvider.class);
@@ -22,17 +23,24 @@ public class VirtualExecutorServiceProvider implements ExecutorServiceProvider {
* Default thread pool executor termination timeout in milliseconds. * Default thread pool executor termination timeout in milliseconds.
*/ */
public static final int TERMINATION_TIMEOUT = 5000; public static final int TERMINATION_TIMEOUT = 5000;
private final String virtualThreadNamePrefix;
public VirtualExecutorServiceProvider(final String virtualThreadNamePrefix) { private final String virtualThreadNamePrefix;
private final int maxConcurrentThreads;
public VirtualExecutorServiceProvider(
final String virtualThreadNamePrefix,
final int maxConcurrentThreads) {
this.virtualThreadNamePrefix = virtualThreadNamePrefix; this.virtualThreadNamePrefix = virtualThreadNamePrefix;
this.maxConcurrentThreads = maxConcurrentThreads;
} }
@Override @Override
public ExecutorService getExecutorService() { public ExecutorService getExecutorService() {
logger.info("Creating executor service with virtual thread per task"); logger.info("Creating executor service with virtual thread per task");
return Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name(virtualThreadNamePrefix, 0).factory()); final ExecutorService executor = Executors.newThreadPerTaskExecutor(
new BoundedVirtualThreadFactory(virtualThreadNamePrefix, maxConcurrentThreads));
return executor;
} }
@Override @Override

View File

@@ -34,7 +34,7 @@ public class BufferingInterceptorIntegrationTest {
final TestController testController = new TestController(); final TestController testController = new TestController();
environment.jersey().register(testController); environment.jersey().register(testController);
environment.jersey().register(new BufferingInterceptor()); environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-")); environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10));
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
} }
} }

View File

@@ -18,10 +18,17 @@ import jakarta.ws.rs.core.Response;
import java.net.URI; import java.net.URI;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import org.eclipse.jetty.util.component.LifeCycle; import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode; import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -34,6 +41,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;

View File

@@ -0,0 +1,153 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.fail;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
class BoundedVirtualThreadFactoryTest {
private final static int MAX_THREADS = 10;
private BoundedVirtualThreadFactory factory;
@BeforeEach
void setUp() {
factory = new BoundedVirtualThreadFactory("test", MAX_THREADS);
}
@Test
void releaseWhenTaskThrows() throws InterruptedException {
final UncaughtExceptionHolder uncaughtExceptionHolder = new UncaughtExceptionHolder();
final Thread t = submit(() -> {
throw new IllegalArgumentException("test");
}, uncaughtExceptionHolder);
assertThat(t).isNotNull();
t.join(Duration.ofSeconds(1));
assertThat(uncaughtExceptionHolder.exception).isNotNull().isInstanceOf(IllegalArgumentException.class);
submitUntilRejected();
}
@Test
void releaseWhenRejected() throws InterruptedException {
submitUntilRejected();
submitUntilRejected();
}
@Test
void executorServiceRejectsAtLimit() throws InterruptedException {
try (final ExecutorService executor = Executors.newThreadPerTaskExecutor(factory)) {
final CountDownLatch latch = new CountDownLatch(1);
for (int i = 0; i < MAX_THREADS; i++) {
executor.submit(() -> {
try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
});
}
assertThatExceptionOfType(RejectedExecutionException.class).isThrownBy(() -> executor.submit(Util.NOOP));
latch.countDown();
executor.shutdown();
assertThat(executor.awaitTermination(5, TimeUnit.SECONDS)).isTrue();
}
}
@Test
void stressTest() throws InterruptedException {
for (int iteration = 0; iteration < 50; iteration++) {
final CountDownLatch latch = new CountDownLatch(MAX_THREADS);
// submit a task that submits a task maxThreads/2 times
final Thread[] threads = new Thread[MAX_THREADS];
for (int i = 0; i < MAX_THREADS; i+=2) {
int outerThreadIndex = i;
int innerThreadIndex = i + 1;
threads[outerThreadIndex] = submit(() -> {
latch.countDown();
threads[innerThreadIndex] = submit(latch::countDown);
});
}
latch.await();
// All threads must be created at this point, wait for them all to complete
for (Thread thread : threads) {
assertThat(thread).isNotNull();
thread.join();
}
assertThat(factory.getRunningThreads()).isEqualTo(0);
}
submitUntilRejected();
}
/**
* Verify we can submit up to the concurrency limit (and no more)
*/
private void submitUntilRejected() throws InterruptedException {
final CountDownLatch finish = new CountDownLatch(1);
final List<Thread> threads = IntStream.range(0, MAX_THREADS).mapToObj(_ -> submit(() -> {
try {
finish.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
})).toList();
assertThat(submit(Util.NOOP)).isNull();
finish.countDown();
for (Thread thread : threads) {
thread.join();
}
assertThat(factory.getRunningThreads()).isEqualTo(0);
}
private Thread submit(final Runnable runnable) {
return submit(runnable, (t, e) ->
fail("Uncaught exception on thread {} : {}", t, e));
}
private Thread submit(final Runnable runnable, final Thread.UncaughtExceptionHandler handler) {
final Thread thread = factory.newThread(runnable);
if (thread == null) {
return null;
}
if (handler != null) {
thread.setUncaughtExceptionHandler(handler);
}
thread.start();
return thread;
}
private static class UncaughtExceptionHolder implements Thread.UncaughtExceptionHandler {
Throwable exception = null;
@Override
public void uncaughtException(final Thread t, final Throwable e) {
exception = e;
}
}
}

View File

@@ -13,20 +13,37 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import java.security.Principal; import java.security.Principal;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class VirtualExecutorServiceProviderTest { class VirtualExecutorServiceProviderTest {
private static final ResourceExtension resources = ResourceExtension.builder() private final TestController testController = new TestController();
private final ResourceExtension resources = ResourceExtension.builder()
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new VirtualExecutorServiceProvider("virtual-thread-")) .addProvider(new VirtualExecutorServiceProvider( "virtual-thread-", 2))
.addResource(new TestController()) .addResource(testController)
.build(); .build();
@AfterEach
void setUp() {
testController.release();
}
@Test @Test
public void testManagedAsyncThread() { public void testManagedAsyncThread() {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
@@ -37,6 +54,22 @@ class VirtualExecutorServiceProviderTest {
assertThat(threadName).startsWith("virtual-thread-"); assertThat(threadName).startsWith("virtual-thread-");
} }
@Test
public void testConcurrencyLimit() throws InterruptedException, TimeoutException {
final BlockingQueue<Response> responses = new LinkedBlockingQueue<>();
final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
for (int i = 0; i < 3; i++) {
executor.submit(() -> responses.offer(resources.getJerseyTest().target("/v1/test/await").request().get()));
}
final Response rejectedResponse = responses.poll(10, TimeUnit.SECONDS);
assertThat(rejectedResponse).isNotNull().extracting(Response::getStatus).isEqualTo(500);
assertThat(responses.isEmpty()).isTrue();
assertThat(testController.release()).isEqualTo(2);
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
}
@Test @Test
public void testUnmanagedThread() { public void testUnmanagedThread() {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
@@ -49,6 +82,7 @@ class VirtualExecutorServiceProviderTest {
@Path("/v1/test") @Path("/v1/test")
public static class TestController { public static class TestController {
private List<CountDownLatch> latches = new ArrayList<>();
@GET @GET
@Path("/managed-async") @Path("/managed-async")
@@ -57,12 +91,34 @@ class VirtualExecutorServiceProviderTest {
return Response.ok().entity(Thread.currentThread().getName()).build(); return Response.ok().entity(Thread.currentThread().getName()).build();
} }
@GET
@Path("/await")
@ManagedAsync
public Response await() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
synchronized (this) {
latches.add(latch);
}
latch.await();
return Response.ok().build();
}
@GET @GET
@Path("/unmanaged") @Path("/unmanaged")
public Response unmanaged() { public Response unmanaged() {
return Response.ok().entity(Thread.currentThread().getName()).build(); return Response.ok().entity(Thread.currentThread().getName()).build();
} }
synchronized int release() {
final Iterator<CountDownLatch> iterator = latches.iterator();
int count;
for (count = 0; iterator.hasNext(); count++) {
iterator.next().countDown();
iterator.remove();
}
return count;
}
} }
public static class TestPrincipal implements Principal { public static class TestPrincipal implements Principal {