Close remote connections only after all active server calls have completed

This commit is contained in:
Jon Chambers
2025-04-11 12:30:18 -04:00
committed by Jon Chambers
parent bb8ce6d981
commit f191c68efc
6 changed files with 328 additions and 13 deletions

View File

@@ -0,0 +1,88 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
class ChannelShutdownInterceptorTest {
private GrpcClientConnectionManager grpcClientConnectionManager;
private ChannelShutdownInterceptor channelShutdownInterceptor;
private ServerCallHandler<String, String> nextCallHandler;
private static final Metadata HEADERS = new Metadata();
@BeforeEach
void setUp() {
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
//noinspection unchecked
nextCallHandler = mock(ServerCallHandler.class);
//noinspection unchecked
when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
}
@Test
void interceptCallComplete() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onComplete();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallCancelled() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onCancel();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallChannelClosing() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
verify(serverCall).close(eq(Status.UNAVAILABLE), any());
}
}

View File

@@ -12,14 +12,38 @@ import org.signal.chat.rpc.EchoServiceGrpc;
public class EchoServiceImpl extends EchoServiceGrpc.EchoServiceImplBase {
@Override
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
public void echo(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(buildResponse(echoRequest));
responseObserver.onCompleted();
}
@Override
public void echo2(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(EchoResponse.newBuilder().setPayload(req.getPayload()).build());
public void echo2(final EchoRequest echoRequest, final StreamObserver<EchoResponse> responseObserver) {
responseObserver.onNext(buildResponse(echoRequest));
responseObserver.onCompleted();
}
@Override
public StreamObserver<EchoRequest> echoStream(final StreamObserver<EchoResponse> responseObserver) {
return new StreamObserver<>() {
@Override
public void onNext(final EchoRequest echoRequest) {
responseObserver.onNext(buildResponse(echoRequest));
}
@Override
public void onError(final Throwable throwable) {
responseObserver.onError(throwable);
}
@Override
public void onCompleted() {
responseObserver.onCompleted();
}
};
}
private static EchoResponse buildResponse(final EchoRequest echoRequest) {
return EchoResponse.newBuilder().setPayload(echoRequest.getPayload()).build();
}
}

View File

@@ -1,6 +1,7 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
@@ -8,10 +9,12 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
@@ -61,6 +64,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
@@ -71,6 +77,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
@@ -83,6 +91,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor;
private static ExecutorService serverCallExecutor;
private static X509Certificate serverTlsCertificate;
@@ -136,7 +145,8 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
static void setUpBeforeAll() throws CertificateException {
nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newSingleThreadExecutor();
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
@@ -171,7 +181,11 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new RequestAttributesServiceImpl())
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.addService(new EchoServiceImpl())
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
}
@@ -182,7 +196,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new RequestAttributesServiceImpl())
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
}
@@ -195,7 +211,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
serverTlsPrivateKey,
nioEventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
@@ -209,7 +225,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
null,
nioEventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
@@ -235,6 +251,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
serverCallExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@ParameterizedTest
@@ -579,6 +599,89 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
@Test
void waitForCallCompletion() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
// Start an open-ended server call and leave it in a non-complete state
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
new StreamObserver<>() {
@Override
public void onNext(final EchoResponse echoResponse) {
responseCountDownLatch.countDown();
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onCompleted() {
}
});
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
// truly started before requesting connection closure.
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should not close until active requests have finished");
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
// Complete the open-ended server call
echoRequestStreamObserver.onCompleted();
assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should close once active requests have finished");
assertTrue(closedByServer.get());
assertEquals(4004, serverCloseStatusCode.get());
} finally {
channel.shutdown();
}
}
}
private NoiseWebSocketTunnelClient.Builder anonymous() {
return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())