mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-20 21:38:06 +01:00
Add a bounded virtual executor service
This commit is contained in:
@@ -34,7 +34,7 @@ public class BufferingInterceptorIntegrationTest {
|
||||
final TestController testController = new TestController();
|
||||
environment.jersey().register(testController);
|
||||
environment.jersey().register(new BufferingInterceptor());
|
||||
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-"));
|
||||
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10));
|
||||
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,10 +18,17 @@ import jakarta.ws.rs.core.Response;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ArrayBlockingQueue;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Semaphore;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.eclipse.jetty.util.component.LifeCycle;
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.api.StatusCode;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -34,6 +41,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
|
||||
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.assertj.core.api.Assertions.fail;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.RejectedExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.IntStream;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class BoundedVirtualThreadFactoryTest {
|
||||
|
||||
private final static int MAX_THREADS = 10;
|
||||
|
||||
private BoundedVirtualThreadFactory factory;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
factory = new BoundedVirtualThreadFactory("test", MAX_THREADS);
|
||||
}
|
||||
|
||||
@Test
|
||||
void releaseWhenTaskThrows() throws InterruptedException {
|
||||
final UncaughtExceptionHolder uncaughtExceptionHolder = new UncaughtExceptionHolder();
|
||||
final Thread t = submit(() -> {
|
||||
throw new IllegalArgumentException("test");
|
||||
}, uncaughtExceptionHolder);
|
||||
assertThat(t).isNotNull();
|
||||
t.join(Duration.ofSeconds(1));
|
||||
assertThat(uncaughtExceptionHolder.exception).isNotNull().isInstanceOf(IllegalArgumentException.class);
|
||||
|
||||
submitUntilRejected();
|
||||
}
|
||||
|
||||
@Test
|
||||
void releaseWhenRejected() throws InterruptedException {
|
||||
submitUntilRejected();
|
||||
submitUntilRejected();
|
||||
}
|
||||
|
||||
@Test
|
||||
void executorServiceRejectsAtLimit() throws InterruptedException {
|
||||
try (final ExecutorService executor = Executors.newThreadPerTaskExecutor(factory)) {
|
||||
|
||||
final CountDownLatch latch = new CountDownLatch(1);
|
||||
for (int i = 0; i < MAX_THREADS; i++) {
|
||||
executor.submit(() -> {
|
||||
try {
|
||||
latch.await();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
}
|
||||
assertThatExceptionOfType(RejectedExecutionException.class).isThrownBy(() -> executor.submit(Util.NOOP));
|
||||
latch.countDown();
|
||||
|
||||
executor.shutdown();
|
||||
assertThat(executor.awaitTermination(5, TimeUnit.SECONDS)).isTrue();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void stressTest() throws InterruptedException {
|
||||
for (int iteration = 0; iteration < 50; iteration++) {
|
||||
final CountDownLatch latch = new CountDownLatch(MAX_THREADS);
|
||||
|
||||
// submit a task that submits a task maxThreads/2 times
|
||||
final Thread[] threads = new Thread[MAX_THREADS];
|
||||
for (int i = 0; i < MAX_THREADS; i+=2) {
|
||||
int outerThreadIndex = i;
|
||||
int innerThreadIndex = i + 1;
|
||||
|
||||
threads[outerThreadIndex] = submit(() -> {
|
||||
latch.countDown();
|
||||
threads[innerThreadIndex] = submit(latch::countDown);
|
||||
});
|
||||
}
|
||||
latch.await();
|
||||
|
||||
// All threads must be created at this point, wait for them all to complete
|
||||
for (Thread thread : threads) {
|
||||
assertThat(thread).isNotNull();
|
||||
thread.join();
|
||||
}
|
||||
|
||||
assertThat(factory.getRunningThreads()).isEqualTo(0);
|
||||
}
|
||||
|
||||
submitUntilRejected();
|
||||
}
|
||||
|
||||
/**
|
||||
* Verify we can submit up to the concurrency limit (and no more)
|
||||
*/
|
||||
private void submitUntilRejected() throws InterruptedException {
|
||||
final CountDownLatch finish = new CountDownLatch(1);
|
||||
final List<Thread> threads = IntStream.range(0, MAX_THREADS).mapToObj(_ -> submit(() -> {
|
||||
try {
|
||||
finish.await();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
})).toList();
|
||||
|
||||
assertThat(submit(Util.NOOP)).isNull();
|
||||
|
||||
finish.countDown();
|
||||
|
||||
for (Thread thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
assertThat(factory.getRunningThreads()).isEqualTo(0);
|
||||
}
|
||||
|
||||
private Thread submit(final Runnable runnable) {
|
||||
return submit(runnable, (t, e) ->
|
||||
fail("Uncaught exception on thread {} : {}", t, e));
|
||||
}
|
||||
|
||||
private Thread submit(final Runnable runnable, final Thread.UncaughtExceptionHandler handler) {
|
||||
final Thread thread = factory.newThread(runnable);
|
||||
if (thread == null) {
|
||||
return null;
|
||||
}
|
||||
if (handler != null) {
|
||||
thread.setUncaughtExceptionHandler(handler);
|
||||
}
|
||||
thread.start();
|
||||
return thread;
|
||||
}
|
||||
|
||||
private static class UncaughtExceptionHolder implements Thread.UncaughtExceptionHandler {
|
||||
Throwable exception = null;
|
||||
|
||||
@Override
|
||||
public void uncaughtException(final Thread t, final Throwable e) {
|
||||
exception = e;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -13,20 +13,37 @@ import jakarta.ws.rs.GET;
|
||||
import jakarta.ws.rs.Path;
|
||||
import jakarta.ws.rs.core.Response;
|
||||
import java.security.Principal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import org.glassfish.jersey.server.ManagedAsync;
|
||||
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class VirtualExecutorServiceProviderTest {
|
||||
|
||||
private static final ResourceExtension resources = ResourceExtension.builder()
|
||||
private final TestController testController = new TestController();
|
||||
private final ResourceExtension resources = ResourceExtension.builder()
|
||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||
.addProvider(new VirtualExecutorServiceProvider("virtual-thread-"))
|
||||
.addResource(new TestController())
|
||||
.addProvider(new VirtualExecutorServiceProvider( "virtual-thread-", 2))
|
||||
.addResource(testController)
|
||||
.build();
|
||||
|
||||
@AfterEach
|
||||
void setUp() {
|
||||
testController.release();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testManagedAsyncThread() {
|
||||
final Response response = resources.getJerseyTest()
|
||||
@@ -37,6 +54,22 @@ class VirtualExecutorServiceProviderTest {
|
||||
assertThat(threadName).startsWith("virtual-thread-");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConcurrencyLimit() throws InterruptedException, TimeoutException {
|
||||
final BlockingQueue<Response> responses = new LinkedBlockingQueue<>();
|
||||
final ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
|
||||
for (int i = 0; i < 3; i++) {
|
||||
executor.submit(() -> responses.offer(resources.getJerseyTest().target("/v1/test/await").request().get()));
|
||||
}
|
||||
final Response rejectedResponse = responses.poll(10, TimeUnit.SECONDS);
|
||||
assertThat(rejectedResponse).isNotNull().extracting(Response::getStatus).isEqualTo(500);
|
||||
|
||||
assertThat(responses.isEmpty()).isTrue();
|
||||
assertThat(testController.release()).isEqualTo(2);
|
||||
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
|
||||
assertThat(responses.poll(1, TimeUnit.SECONDS)).isNotNull().extracting(Response::getStatus).isEqualTo(200);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUnmanagedThread() {
|
||||
final Response response = resources.getJerseyTest()
|
||||
@@ -49,6 +82,7 @@ class VirtualExecutorServiceProviderTest {
|
||||
|
||||
@Path("/v1/test")
|
||||
public static class TestController {
|
||||
private List<CountDownLatch> latches = new ArrayList<>();
|
||||
|
||||
@GET
|
||||
@Path("/managed-async")
|
||||
@@ -57,12 +91,34 @@ class VirtualExecutorServiceProviderTest {
|
||||
return Response.ok().entity(Thread.currentThread().getName()).build();
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/await")
|
||||
@ManagedAsync
|
||||
public Response await() throws InterruptedException {
|
||||
final CountDownLatch latch = new CountDownLatch(1);
|
||||
synchronized (this) {
|
||||
latches.add(latch);
|
||||
}
|
||||
latch.await();
|
||||
return Response.ok().build();
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/unmanaged")
|
||||
public Response unmanaged() {
|
||||
return Response.ok().entity(Thread.currentThread().getName()).build();
|
||||
}
|
||||
|
||||
synchronized int release() {
|
||||
final Iterator<CountDownLatch> iterator = latches.iterator();
|
||||
int count;
|
||||
for (count = 0; iterator.hasNext(); count++) {
|
||||
iterator.next().countDown();
|
||||
iterator.remove();
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class TestPrincipal implements Principal {
|
||||
|
||||
Reference in New Issue
Block a user