mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-18 05:45:21 +01:00
Add a bounded virtual executor service
This commit is contained in:
@@ -316,7 +316,7 @@ public class WhisperServerConfiguration extends Configuration {
|
|||||||
@Valid
|
@Valid
|
||||||
@NotNull
|
@NotNull
|
||||||
@JsonProperty
|
@JsonProperty
|
||||||
private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration(Duration.ofMillis(1));
|
private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration();
|
||||||
|
|
||||||
@Valid
|
@Valid
|
||||||
@NotNull
|
@NotNull
|
||||||
|
|||||||
@@ -258,6 +258,7 @@ import org.whispersystems.textsecuregcm.subscriptions.GooglePlayBillingManager;
|
|||||||
import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
|
import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
|
||||||
import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
|
import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
|
||||||
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
|
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
|
||||||
|
import org.whispersystems.textsecuregcm.util.ManagedExecutors;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
||||||
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
|
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
|
||||||
@@ -395,8 +396,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||||||
|
|
||||||
environment.lifecycle().manage(new ManagedAwsCrt());
|
environment.lifecycle().manage(new ManagedAwsCrt());
|
||||||
|
|
||||||
final ExecutorService awsSdkMetricsExecutor = environment.lifecycle()
|
final ExecutorService awsSdkMetricsExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
|
||||||
.virtualExecutorService(name(getClass(), "awsSdkMetrics-%d"));
|
"awsSdkMetrics",
|
||||||
|
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
|
||||||
|
environment);
|
||||||
|
|
||||||
final DynamoDbAsyncClient dynamoDbAsyncClient = config.getDynamoDbClientConfiguration()
|
final DynamoDbAsyncClient dynamoDbAsyncClient = config.getDynamoDbClientConfiguration()
|
||||||
.buildAsyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbAsync"));
|
.buildAsyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbAsync"));
|
||||||
@@ -561,14 +564,23 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||||||
.maxThreads(2)
|
.maxThreads(2)
|
||||||
.minThreads(2)
|
.minThreads(2)
|
||||||
.build();
|
.build();
|
||||||
ExecutorService googlePlayBillingExecutor = environment.lifecycle()
|
|
||||||
.virtualExecutorService(name(getClass(), "googlePlayBilling-%d"));
|
ExecutorService googlePlayBillingExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
|
||||||
ExecutorService appleAppStoreExecutor = environment.lifecycle()
|
"googlePlayBilling",
|
||||||
.virtualExecutorService(name(getClass(), "appleAppStore-%d"));
|
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
|
||||||
ExecutorService clientEventExecutor = environment.lifecycle()
|
environment);
|
||||||
.virtualExecutorService(name(getClass(), "clientEvent-%d"));
|
ExecutorService appleAppStoreExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
|
||||||
ExecutorService disconnectionRequestListenerExecutor = environment.lifecycle()
|
"appleAppStore",
|
||||||
.virtualExecutorService(name(getClass(), "disconnectionRequest-%d"));
|
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
|
||||||
|
environment);
|
||||||
|
ExecutorService clientEventExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
|
||||||
|
"clientEvent",
|
||||||
|
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
|
||||||
|
environment);
|
||||||
|
ExecutorService disconnectionRequestListenerExecutor = ManagedExecutors.newVirtualThreadPerTaskExecutor(
|
||||||
|
"disconnectionRequest",
|
||||||
|
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor(),
|
||||||
|
environment);
|
||||||
|
|
||||||
ScheduledExecutorService appleAppStoreRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "appleAppStoreRetry").threads(1).build();
|
ScheduledExecutorService appleAppStoreRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "appleAppStoreRetry").threads(1).build();
|
||||||
ScheduledExecutorService subscriptionProcessorRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "subscriptionProcessorRetry").threads(1).build();
|
ScheduledExecutorService subscriptionProcessorRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "subscriptionProcessorRetry").threads(1).build();
|
||||||
@@ -976,7 +988,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||||||
environment.jersey().register(new BufferingInterceptor());
|
environment.jersey().register(new BufferingInterceptor());
|
||||||
environment.jersey().register(new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager));
|
environment.jersey().register(new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager));
|
||||||
|
|
||||||
environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-"));
|
environment.jersey().register(new VirtualExecutorServiceProvider(
|
||||||
|
"managed-async-virtual-thread",
|
||||||
|
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor()));
|
||||||
environment.jersey().register(new RateLimitByIpFilter(rateLimiters));
|
environment.jersey().register(new RateLimitByIpFilter(rateLimiters));
|
||||||
environment.jersey().register(new RequestStatisticsFilter(TrafficSource.HTTP));
|
environment.jersey().register(new RequestStatisticsFilter(TrafficSource.HTTP));
|
||||||
environment.jersey().register(MultiRecipientMessageProvider.class);
|
environment.jersey().register(MultiRecipientMessageProvider.class);
|
||||||
@@ -987,7 +1001,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||||||
///
|
///
|
||||||
WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = new WebSocketEnvironment<>(environment,
|
WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = new WebSocketEnvironment<>(environment,
|
||||||
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
|
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
|
||||||
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
|
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider(
|
||||||
|
"managed-async-websocket-virtual-thread",
|
||||||
|
config.getVirtualThreadConfiguration().maxConcurrentThreadsPerExecutor()));
|
||||||
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
|
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
|
||||||
webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter(
|
webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter(
|
||||||
config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));
|
config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));
|
||||||
|
|||||||
@@ -6,4 +6,20 @@ package org.whispersystems.textsecuregcm.configuration;
|
|||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
|
||||||
public record VirtualThreadConfiguration(Duration pinEventThreshold) {}
|
public record VirtualThreadConfiguration(
|
||||||
|
Duration pinEventThreshold,
|
||||||
|
Integer maxConcurrentThreadsPerExecutor) {
|
||||||
|
|
||||||
|
public VirtualThreadConfiguration() {
|
||||||
|
this(null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public VirtualThreadConfiguration {
|
||||||
|
if (maxConcurrentThreadsPerExecutor == null) {
|
||||||
|
maxConcurrentThreadsPerExecutor = 1_000_000;
|
||||||
|
}
|
||||||
|
if (pinEventThreshold == null) {
|
||||||
|
pinEventThreshold = Duration.ofMillis(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,94 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.util;
|
||||||
|
|
||||||
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
|
import io.micrometer.core.instrument.Counter;
|
||||||
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import io.micrometer.core.instrument.Tags;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A thread factory that creates virtual threads but limits the total number of virtual threads created.
|
||||||
|
*/
|
||||||
|
public class BoundedVirtualThreadFactory implements ThreadFactory {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(BoundedVirtualThreadFactory.class);
|
||||||
|
|
||||||
|
private final AtomicInteger runningThreads = new AtomicInteger();
|
||||||
|
private final ThreadFactory delegate;
|
||||||
|
private final int maxConcurrentThreads;
|
||||||
|
|
||||||
|
private final Counter created;
|
||||||
|
private final Counter completed;
|
||||||
|
|
||||||
|
public BoundedVirtualThreadFactory(final String threadPoolName, final int maxConcurrentThreads) {
|
||||||
|
this.maxConcurrentThreads = maxConcurrentThreads;
|
||||||
|
|
||||||
|
final Tags tags = Tags.of("pool", threadPoolName);
|
||||||
|
Metrics.gauge(
|
||||||
|
MetricsUtil.name(BoundedVirtualThreadFactory.class, "active"),
|
||||||
|
tags, runningThreads, (rt) -> (double) rt.get());
|
||||||
|
this.created = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "created"), tags);
|
||||||
|
this.completed = Metrics.counter(MetricsUtil.name(BoundedVirtualThreadFactory.class, "completed"), tags);
|
||||||
|
|
||||||
|
// The virtual thread factory will initialize thread names by appending the thread index to the provided prefix
|
||||||
|
this.delegate = Thread.ofVirtual().name(threadPoolName + "-", 0).factory();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Thread newThread(final Runnable r) {
|
||||||
|
if (!tryAcquire()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
Thread thread = null;
|
||||||
|
try {
|
||||||
|
final Runnable wrapped = () -> {
|
||||||
|
try {
|
||||||
|
r.run();
|
||||||
|
} finally {
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
thread = delegate.newThread(wrapped);
|
||||||
|
} finally {
|
||||||
|
if (thread == null) {
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return thread;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
int getRunningThreads() {
|
||||||
|
return runningThreads.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean tryAcquire() {
|
||||||
|
int old;
|
||||||
|
do {
|
||||||
|
old = runningThreads.get();
|
||||||
|
if (old >= maxConcurrentThreads) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} while (!runningThreads.compareAndSet(old, old + 1));
|
||||||
|
created.increment();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void release() {
|
||||||
|
int updated = runningThreads.decrementAndGet();
|
||||||
|
if (updated < 0) {
|
||||||
|
logger.error("Released a thread and count was {}, which should never happen", updated);
|
||||||
|
}
|
||||||
|
completed.increment();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.util;
|
||||||
|
|
||||||
|
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||||
|
|
||||||
|
import io.dropwizard.core.setup.Environment;
|
||||||
|
import io.dropwizard.lifecycle.ExecutorServiceManager;
|
||||||
|
import io.dropwizard.util.Duration;
|
||||||
|
import io.micrometer.core.instrument.MeterRegistry;
|
||||||
|
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build Executor Services managed by dropwizard, supplementing executors provided by
|
||||||
|
* {@link io.dropwizard.lifecycle.setup.LifecycleEnvironment#executorService}
|
||||||
|
*/
|
||||||
|
public class ManagedExecutors {
|
||||||
|
|
||||||
|
private static final Duration SHUTDOWN_DURATION = Duration.seconds(5);
|
||||||
|
|
||||||
|
private ManagedExecutors() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ExecutorService newVirtualThreadPerTaskExecutor(
|
||||||
|
final String threadNamePrefix,
|
||||||
|
final int maxConcurrentThreads,
|
||||||
|
final Environment environment) {
|
||||||
|
|
||||||
|
final BoundedVirtualThreadFactory threadFactory =
|
||||||
|
new BoundedVirtualThreadFactory(threadNamePrefix, maxConcurrentThreads);
|
||||||
|
final ExecutorService virtualThreadExecutor = Executors.newThreadPerTaskExecutor(threadFactory);
|
||||||
|
environment.lifecycle()
|
||||||
|
.manage(new ExecutorServiceManager(virtualThreadExecutor, SHUTDOWN_DURATION, threadNamePrefix));
|
||||||
|
return virtualThreadExecutor;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import org.slf4j.LoggerFactory;
|
|||||||
|
|
||||||
@ManagedAsyncExecutor
|
@ManagedAsyncExecutor
|
||||||
public class VirtualExecutorServiceProvider implements ExecutorServiceProvider {
|
public class VirtualExecutorServiceProvider implements ExecutorServiceProvider {
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(VirtualExecutorServiceProvider.class);
|
private static final Logger logger = LoggerFactory.getLogger(VirtualExecutorServiceProvider.class);
|
||||||
|
|
||||||
|
|
||||||
@@ -22,17 +23,24 @@ public class VirtualExecutorServiceProvider implements ExecutorServiceProvider {
|
|||||||
* Default thread pool executor termination timeout in milliseconds.
|
* Default thread pool executor termination timeout in milliseconds.
|
||||||
*/
|
*/
|
||||||
public static final int TERMINATION_TIMEOUT = 5000;
|
public static final int TERMINATION_TIMEOUT = 5000;
|
||||||
private final String virtualThreadNamePrefix;
|
|
||||||
|
|
||||||
public VirtualExecutorServiceProvider(final String virtualThreadNamePrefix) {
|
private final String virtualThreadNamePrefix;
|
||||||
|
private final int maxConcurrentThreads;
|
||||||
|
|
||||||
|
public VirtualExecutorServiceProvider(
|
||||||
|
final String virtualThreadNamePrefix,
|
||||||
|
final int maxConcurrentThreads) {
|
||||||
this.virtualThreadNamePrefix = virtualThreadNamePrefix;
|
this.virtualThreadNamePrefix = virtualThreadNamePrefix;
|
||||||
|
this.maxConcurrentThreads = maxConcurrentThreads;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ExecutorService getExecutorService() {
|
public ExecutorService getExecutorService() {
|
||||||
logger.info("Creating executor service with virtual thread per task");
|
logger.info("Creating executor service with virtual thread per task");
|
||||||
return Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name(virtualThreadNamePrefix, 0).factory());
|
final ExecutorService executor = Executors.newThreadPerTaskExecutor(
|
||||||
|
new BoundedVirtualThreadFactory(virtualThreadNamePrefix, maxConcurrentThreads));
|
||||||
|
return executor;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ public class BufferingInterceptorIntegrationTest {
|
|||||||
final TestController testController = new TestController();
|
final TestController testController = new TestController();
|
||||||
environment.jersey().register(testController);
|
environment.jersey().register(testController);
|
||||||
environment.jersey().register(new BufferingInterceptor());
|
environment.jersey().register(new BufferingInterceptor());
|
||||||
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-"));
|
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10));
|
||||||
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,10 +18,17 @@ import jakarta.ws.rs.core.Response;
|
|||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ArrayBlockingQueue;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.Semaphore;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import org.eclipse.jetty.util.component.LifeCycle;
|
import org.eclipse.jetty.util.component.LifeCycle;
|
||||||
import org.eclipse.jetty.websocket.api.Session;
|
import org.eclipse.jetty.websocket.api.Session;
|
||||||
import org.eclipse.jetty.websocket.api.StatusCode;
|
import org.eclipse.jetty.websocket.api.StatusCode;
|
||||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.junit.jupiter.api.AfterAll;
|
import org.junit.jupiter.api.AfterAll;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
@@ -34,6 +41,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
|
|||||||
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
|
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
|
||||||
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
||||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
||||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||||
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
|
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
|
||||||
|
|||||||
@@ -0,0 +1,153 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.util;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||||
|
import static org.assertj.core.api.Assertions.fail;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.RejectedExecutionException;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
class BoundedVirtualThreadFactoryTest {
|
||||||
|
|
||||||
|
private final static int MAX_THREADS = 10;
|
||||||
|
|
||||||
|
private BoundedVirtualThreadFactory factory;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
factory = new BoundedVirtualThreadFactory("test", MAX_THREADS);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void releaseWhenTaskThrows() throws InterruptedException {
|
||||||
|
final UncaughtExceptionHolder uncaughtExceptionHolder = new UncaughtExceptionHolder();
|
||||||
|
final Thread t = submit(() -> {
|
||||||
|
throw new IllegalArgumentException("test");
|
||||||
|
}, uncaughtExceptionHolder);
|
||||||
|
assertThat(t).isNotNull();
|
||||||
|
t.join(Duration.ofSeconds(1));
|
||||||
|
assertThat(uncaughtExceptionHolder.exception).isNotNull().isInstanceOf(IllegalArgumentException.class);
|
||||||
|
|
||||||
|
submitUntilRejected();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void releaseWhenRejected() throws InterruptedException {
|
||||||
|
submitUntilRejected();
|
||||||
|
submitUntilRejected();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void executorServiceRejectsAtLimit() throws InterruptedException {
|
||||||
|
try (final ExecutorService executor = Executors.newThreadPerTaskExecutor(factory)) {
|
||||||
|
|
||||||
|
final CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
for (int i = 0; i < MAX_THREADS; i++) {
|
||||||
|
executor.submit(() -> {
|
||||||
|
try {
|
||||||
|
latch.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
assertThatExceptionOfType(RejectedExecutionException.class).isThrownBy(() -> executor.submit(Util.NOOP));
|
||||||
|
latch.countDown();
|
||||||
|
|
||||||
|
executor.shutdown();
|
||||||
|
assertThat(executor.awaitTermination(5, TimeUnit.SECONDS)).isTrue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void stressTest() throws InterruptedException {
|
||||||
|
for (int iteration = 0; iteration < 50; iteration++) {
|
||||||
|
final CountDownLatch latch = new CountDownLatch(MAX_THREADS);
|
||||||
|
|
||||||
|
// submit a task that submits a task maxThreads/2 times
|
||||||
|
final Thread[] threads = new Thread[MAX_THREADS];
|
||||||
|
for (int i = 0; i < MAX_THREADS; i+=2) {
|
||||||
|
int outerThreadIndex = i;
|
||||||
|
int innerThreadIndex = i + 1;
|
||||||
|
|
||||||
|
threads[outerThreadIndex] = submit(() -> {
|
||||||
|
latch.countDown();
|
||||||
|
threads[innerThreadIndex] = submit(latch::countDown);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
latch.await();
|
||||||
|
|
||||||
|
// All threads must be created at this point, wait for them all to complete
|
||||||
|
for (Thread thread : threads) {
|
||||||
|
assertThat(thread).isNotNull();
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertThat(factory.getRunningThreads()).isEqualTo(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
submitUntilRejected();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Verify we can submit up to the concurrency limit (and no more)
|
||||||
|
*/
|
||||||
|
private void submitUntilRejected() throws InterruptedException {
|
||||||
|
final CountDownLatch finish = new CountDownLatch(1);
|
||||||
|
final List<Thread> threads = IntStream.range(0, MAX_THREADS).mapToObj(_ -> submit(() -> {
|
||||||
|
try {
|
||||||
|
finish.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
})).toList();
|
||||||
|
|
||||||
|
assertThat(submit(Util.NOOP)).isNull();
|
||||||
|
|
||||||
|
finish.countDown();
|
||||||
|
|
||||||
|
for (Thread thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
assertThat(factory.getRunningThreads()).isEqualTo(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Thread submit(final Runnable runnable) {
|
||||||
|
return submit(runnable, (t, e) ->
|
||||||
|
fail("Uncaught exception on thread {} : {}", t, e));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Thread submit(final Runnable runnable, final Thread.UncaughtExceptionHandler handler) {
|
||||||
|
final Thread thread = factory.newThread(runnable);
|
||||||
|
if (thread == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (handler != null) {
|
||||||
|
thread.setUncaughtExceptionHandler(handler);
|
||||||
|
}
|
||||||
|
thread.start();
|
||||||
|
return thread;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class UncaughtExceptionHolder implements Thread.UncaughtExceptionHandler {
|
||||||
|
Throwable exception = null;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void uncaughtException(final Thread t, final Throwable e) {
|
||||||
|
exception = e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -13,20 +13,37 @@ import jakarta.ws.rs.GET;
|
|||||||
import jakarta.ws.rs.Path;
|
import jakarta.ws.rs.Path;
|
||||||
import jakarta.ws.rs.core.Response;
|
import jakarta.ws.rs.core.Response;
|
||||||
import java.security.Principal;
|
import java.security.Principal;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.BlockingQueue;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.LinkedBlockingQueue;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.TimeoutException;
|
||||||
import org.glassfish.jersey.server.ManagedAsync;
|
import org.glassfish.jersey.server.ManagedAsync;
|
||||||
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
|
||||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
class VirtualExecutorServiceProviderTest {
|
class VirtualExecutorServiceProviderTest {
|
||||||
|
|
||||||
private static final ResourceExtension resources = ResourceExtension.builder()
|
private final TestController testController = new TestController();
|
||||||
|
private final ResourceExtension resources = ResourceExtension.builder()
|
||||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||||
.addProvider(new VirtualExecutorServiceProvider("virtual-thread-"))
|
.addProvider(new VirtualExecutorServiceProvider( "virtual-thread-", 2))
|
||||||
.addResource(new TestController())
|
.addResource(testController)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void setUp() {
|
||||||
|
testController.release();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testManagedAsyncThread() {
|
public void testManagedAsyncThread() {
|
||||||
final Response response = resources.getJerseyTest()
|
final Response response = resources.getJerseyTest()
|
||||||
@@ -37,6 +54,22 @@ class VirtualExecutorServiceProviderTest {
|
|||||||
assertThat(threadName).startsWith("virtual-thread-");
|
assertThat(threadName).startsWith("virtual-thread-");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConcurrencyLimit() throws InterruptedException, TimeoutException {
|
||||||
|
final BlockingQueue<Response> responses = new LinkedBlockingQueue<>();
|
||||||
|
final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
executor.submit(() -> responses.offer(resources.getJerseyTest().target("/v1/test/await").request().get()));
|
||||||
|
}
|
||||||
|
final Response rejectedResponse = responses.poll(10, TimeUnit.SECONDS);
|
||||||
|
assertThat(rejectedResponse).isNotNull().extracting(Response::getStatus).isEqualTo(500);
|
||||||
|
|
||||||
|
assertThat(responses.isEmpty()).isTrue();
|
||||||
|
assertThat(testController.release()).isEqualTo(2);
|
||||||
|
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
|
||||||
|
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUnmanagedThread() {
|
public void testUnmanagedThread() {
|
||||||
final Response response = resources.getJerseyTest()
|
final Response response = resources.getJerseyTest()
|
||||||
@@ -49,6 +82,7 @@ class VirtualExecutorServiceProviderTest {
|
|||||||
|
|
||||||
@Path("/v1/test")
|
@Path("/v1/test")
|
||||||
public static class TestController {
|
public static class TestController {
|
||||||
|
private List<CountDownLatch> latches = new ArrayList<>();
|
||||||
|
|
||||||
@GET
|
@GET
|
||||||
@Path("/managed-async")
|
@Path("/managed-async")
|
||||||
@@ -57,12 +91,34 @@ class VirtualExecutorServiceProviderTest {
|
|||||||
return Response.ok().entity(Thread.currentThread().getName()).build();
|
return Response.ok().entity(Thread.currentThread().getName()).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@GET
|
||||||
|
@Path("/await")
|
||||||
|
@ManagedAsync
|
||||||
|
public Response await() throws InterruptedException {
|
||||||
|
final CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
synchronized (this) {
|
||||||
|
latches.add(latch);
|
||||||
|
}
|
||||||
|
latch.await();
|
||||||
|
return Response.ok().build();
|
||||||
|
}
|
||||||
|
|
||||||
@GET
|
@GET
|
||||||
@Path("/unmanaged")
|
@Path("/unmanaged")
|
||||||
public Response unmanaged() {
|
public Response unmanaged() {
|
||||||
return Response.ok().entity(Thread.currentThread().getName()).build();
|
return Response.ok().entity(Thread.currentThread().getName()).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
synchronized int release() {
|
||||||
|
final Iterator<CountDownLatch> iterator = latches.iterator();
|
||||||
|
int count;
|
||||||
|
for (count = 0; iterator.hasNext(); count++) {
|
||||||
|
iterator.next().countDown();
|
||||||
|
iterator.remove();
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestPrincipal implements Principal {
|
public static class TestPrincipal implements Principal {
|
||||||
|
|||||||
Reference in New Issue
Block a user