Add NoiseDirect framing protocol

This commit is contained in:
ravi-signal
2025-04-30 15:05:05 -05:00
committed by GitHub
parent e285bf1a52
commit 0398e02690
63 changed files with 2111 additions and 1450 deletions

View File

@@ -4,7 +4,7 @@ import io.netty.util.ResourceLeakDetector;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
abstract class AbstractLeakDetectionTest {
public abstract class AbstractLeakDetectionTest {
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;

View File

@@ -119,15 +119,15 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
* waiting messages in the channel, return null.
*/
byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
final BinaryWebSocketFrame responseFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
final ByteBuf responseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
if (responseFrame == null) {
return null;
}
final byte[] plaintext = new byte[responseFrame.content().readableBytes() - 16];
final byte[] plaintext = new byte[responseFrame.readableBytes() - 16];
final int read = clientCipherPair.getReceiver().decryptWithAd(null,
ByteBufUtil.getBytes(responseFrame.content()), 0,
ByteBufUtil.getBytes(responseFrame), 0,
plaintext, 0,
responseFrame.content().readableBytes());
responseFrame.readableBytes());
assertEquals(read, plaintext.length);
return plaintext;
}
@@ -140,7 +140,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(content).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
@@ -150,18 +150,18 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
@Test
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
final ByteBuf[] frames = new ByteBuf[7];
for (int i = 0; i < frames.length; i++) {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
frames[i] = Unpooled.wrappedBuffer(contentBytes);
embeddedChannel.writeOneInbound(frames[i]).await();
}
for (final BinaryWebSocketFrame frame : frames) {
for (final ByteBuf frame : frames) {
assertEquals(0, frame.refCnt());
}
@@ -169,11 +169,11 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
}
@Test
void handleNonWebSocketBinaryFrame() throws Throwable {
void handleNonByteBufBinaryFrame() throws Throwable {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
final BinaryWebSocketFrame message = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
@@ -192,7 +192,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(ciphertext);
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
assertEquals(0, ciphertextFrame.refCnt());
@@ -206,7 +206,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
final byte[] bogusCiphertext = new byte[32];
io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(bogusCiphertext);
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
assertEquals(0, ciphertextFrame.refCnt());
@@ -235,11 +235,11 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
assertTrue(writePlaintextFuture.await().isSuccess());
assertEquals(0, plaintextBuffer.refCnt());
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
final ByteBuf ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
assertNotNull(ciphertextFrame);
assertTrue(embeddedChannel.outboundMessages().isEmpty());
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
ciphertextFrame.release();
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
@@ -272,10 +272,10 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
final byte[] decryptedPlaintext = new byte[plaintextLength];
int plaintextOffset = 0;
BinaryWebSocketFrame ciphertextFrame;
while ((ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll()) != null) {
assertTrue(ciphertextFrame.content().readableBytes() <= Noise.MAX_PACKET_LEN);
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
ByteBuf ciphertextFrame;
while ((ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll()) != null) {
assertTrue(ciphertextFrame.readableBytes() <= Noise.MAX_PACKET_LEN);
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
ciphertextFrame.release();
plaintextOffset += clientCipherStatePair.getReceiver()
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
@@ -289,7 +289,7 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
public void writeHugeInboundMessage() throws Throwable {
doHandshake();
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
embeddedChannel.pipeline().fireChannelRead(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(big)));
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
assertThrows(NoiseException.class, embeddedChannel::checkException);
}
}

View File

@@ -0,0 +1,426 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor;
private static ExecutorService serverCallExecutor;
private GrpcClientConnectionManager grpcClientConnectionManager;
private ClientPublicKeysManager clientPublicKeysManager;
private ECKeyPair serverKeyPair;
private ECKeyPair clientKeyPair;
private ManagedLocalGrpcServer authenticatedGrpcServer;
private ManagedLocalGrpcServer anonymousGrpcServer;
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
public static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
@BeforeAll
static void setUpBeforeAll() {
nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
}
@BeforeEach
void setUp() throws Exception {
clientKeyPair = Curve.generateKeyPair();
serverKeyPair = Curve.generateKeyPair();
grpcClientConnectionManager = new GrpcClientConnectionManager();
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.addService(new EchoServiceImpl())
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
}
};
authenticatedGrpcServer.start();
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
}
};
anonymousGrpcServer.start();
this.start(
nioEventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
}
protected abstract void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception;
protected abstract void stop() throws Exception;
protected abstract NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey);
public void assertClosedWith(final NoiseTunnelClient client, final CloseFrameEvent.CloseReason reason)
throws ExecutionException, InterruptedException, TimeoutException {
final CloseFrameEvent result = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
assertEquals(reason, result.closeReason());
}
@AfterEach
void tearDown() throws Exception {
authenticatedGrpcServer.stop();
anonymousGrpcServer.stop();
this.stop();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
serverCallExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseTunnelClient client = authenticated()
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAuthenticatedBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseTunnelClient client = authenticated()
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException, ExecutionException, TimeoutException {
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
}
}
@Test
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException, ExecutionException, TimeoutException {
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.AUTHENTICATION_ERROR);
}
}
@Test
void connectAnonymous() throws InterruptedException {
try (final NoiseTunnelClient client = anonymous().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAnonymousBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseTunnelClient client = anonymous()
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
protected ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
return NettyChannelBuilder.forAddress(localAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(defaultEventLoopGroup)
.usePlaintext()
.build();
}
@Test
void closeForReauthentication() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
final CloseFrameEvent closeEvent = client.closeFrameFuture().get(2, TimeUnit.SECONDS);
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeEvent.closeReason());
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeEvent.closeInitiator());
} finally {
channel.shutdown();
}
}
}
@Test
void waitForCallCompletion() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
// Start an open-ended server call and leave it in a non-complete state
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
new StreamObserver<>() {
@Override
public void onNext(final EchoResponse echoResponse) {
responseCountDownLatch.countDown();
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onCompleted() {
}
});
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
// truly started before requesting connection closure.
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
try {
client.closeFrameFuture().get(100, TimeUnit.MILLISECONDS);
fail("Channel should not close until active requests have finished");
} catch (TimeoutException e) {
}
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
// Complete the open-ended server call
echoRequestStreamObserver.onCompleted();
final CloseFrameEvent closeFrameEvent = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeFrameEvent.closeInitiator());
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeFrameEvent.closeReason());
} finally {
channel.shutdown();
}
}
}
protected NoiseTunnelClient.Builder anonymous() {
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey());
}
protected NoiseTunnelClient.Builder authenticated() {
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey())
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID);
}
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
return includeProxyMesage
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443)
: null;
}
}

View File

@@ -1,18 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
class ClientErrorHandler extends ErrorHandler {
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
setWebsocketHandshakeComplete();
}
}
super.userEventTriggered(context, event);
}
}

View File

@@ -1,198 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.Noise;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
/**
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
* gRPC server
*/
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final boolean useTls;
@Nullable private final X509Certificate trustedServerCertificate;
private final URI websocketUri;
private final boolean authenticated;
@Nullable private final ECKeyPair ecKeyPair;
private final ECPublicKey serverPublicKey;
@Nullable private final UUID accountIdentifier;
private final byte deviceId;
private final HttpHeaders headers;
private final SocketAddress remoteServerAddress;
private final WebSocketCloseListener webSocketCloseListener;
@Nullable private final Supplier<HAProxyMessage> proxyMessageSupplier;
// If provided, will be sent with the payload in the noise handshake
private final byte[] fastOpenRequest;
private final List<Object> pendingReads = new ArrayList<>();
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
EstablishRemoteConnectionHandler(
final boolean useTls,
@Nullable final X509Certificate trustedServerCertificate,
final URI websocketUri,
final boolean authenticated,
@Nullable final ECKeyPair ecKeyPair,
final ECPublicKey serverPublicKey,
@Nullable final UUID accountIdentifier,
final byte deviceId,
final HttpHeaders headers,
final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener,
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier,
@Nullable byte[] fastOpenRequest) {
this.useTls = useTls;
this.trustedServerCertificate = trustedServerCertificate;
this.websocketUri = websocketUri;
this.authenticated = authenticated;
this.ecKeyPair = ecKeyPair;
this.serverPublicKey = serverPublicKey;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.headers = headers;
this.remoteServerAddress = remoteServerAddress;
this.webSocketCloseListener = webSocketCloseListener;
this.proxyMessageSupplier = proxyMessageSupplier;
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
}
@Override
public void handlerAdded(final ChannelHandlerContext localContext) {
new Bootstrap()
.channel(NioSocketChannel.class)
.group(localContext.channel().eventLoop())
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel channel) throws SSLException {
if (proxyMessageSupplier != null) {
// In a production setting, we'd want some mechanism to remove these handlers after the initial message
// were sent. Since this is just for testing, though, we can tolerate the inefficiency of leaving a
// pair of inert handlers in the pipeline.
channel.pipeline()
.addLast(HAProxyMessageEncoder.INSTANCE)
.addLast(new HAProxyMessageSender(proxyMessageSupplier));
}
if (useTls) {
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (trustedServerCertificate != null) {
sslContextBuilder.trustManager(trustedServerCertificate);
}
channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc()));
}
final NoiseClientHandshakeHelper helper = authenticated
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
: NoiseClientHandshakeHelper.NK(serverPublicKey);
channel.pipeline()
.addLast(new HttpClientCodec())
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
// want to react to them on our own, we need to catch them before they hit that handler.
.addLast(new InboundCloseWebSocketFrameHandler(webSocketCloseListener))
.addLast(new WebSocketClientProtocolHandler(websocketUri,
WebSocketVersion.V13,
null,
false,
headers,
Noise.MAX_PACKET_LEN,
10_000))
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
// Listens for a Websocket HANDSHAKE_COMPLETE and begins the noise handshake when it is done
.addLast(new NoiseClientHandshakeHandler(helper, initialPayload()))
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
throws Exception {
if (event instanceof NoiseClientHandshakeCompleteEvent handshakeCompleteEvent) {
remoteContext.pipeline()
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
// If there was a payload response on the handshake, write it back to our gRPC client
handshakeCompleteEvent.fastResponse().ifPresent(plaintext ->
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
pendingReads.forEach(localContext::fireChannelRead);
pendingReads.clear();
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
}
super.userEventTriggered(remoteContext, event);
}
})
.addLast(new ClientErrorHandler());
}
})
.connect(remoteServerAddress)
.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
// Close the local connection if the remote channel closes and vice versa
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
} else {
localContext.close();
}
});
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
pendingReads.add(message);
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
pendingReads.forEach(ReferenceCountUtil::release);
pendingReads.clear();
}
private byte[] initialPayload() {
if (!authenticated) {
return fastOpenRequest;
}
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
bb.putLong(accountIdentifier.getMostSignificantBits());
bb.putLong(accountIdentifier.getLeastSignificantBits());
bb.put(deviceId);
bb.put(fastOpenRequest);
bb.flip();
return bb.array();
}
}

View File

@@ -1,9 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}

View File

@@ -146,14 +146,14 @@ class GrpcClientConnectionManagerTest {
@ParameterizedTest
@MethodSource
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress,
void handleHandshakeInitiatedRequestAttributes(final InetAddress preferredRemoteAddress,
final String userAgentHeader,
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel,
GrpcClientConnectionManager.handleHandshakeInitiated(embeddedChannel,
preferredRemoteAddress,
userAgentHeader,
acceptLanguageHeader);
@@ -162,7 +162,7 @@ class GrpcClientConnectionManagerTest {
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
}
private static List<Arguments> handleHandshakeCompleteRequestAttributes() {
private static List<Arguments> handleHandshakeInitiatedRequestAttributes() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
return List.of(

View File

@@ -1,23 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
class InboundCloseWebSocketFrameHandler extends ChannelInboundHandlerAdapter {
private final WebSocketCloseListener webSocketCloseListener;
public InboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
this.webSocketCloseListener = webSocketCloseListener;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
webSocketCloseListener.handleWebSocketClosedByServer(closeWebSocketFrame.statusCode());
}
super.channelRead(context, message);
}
}

View File

@@ -10,9 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.util.Optional;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
@@ -49,22 +49,18 @@ class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
assertEquals(
initiateHandshakeMessageLength,
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
final BinaryWebSocketFrame initiateHandshakeFrame = new BinaryWebSocketFrame(
Unpooled.wrappedBuffer(initiateHandshakeMessage));
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeFrame).await().isSuccess());
assertEquals(0, initiateHandshakeFrame.refCnt());
final ByteBuf initiateHandshakeMessageBuf = Unpooled.wrappedBuffer(initiateHandshakeMessage);
assertTrue(embeddedChannel.writeOneInbound(initiateHandshakeMessageBuf).await().isSuccess());
assertEquals(0, initiateHandshakeMessageBuf.refCnt());
embeddedChannel.runPendingTasks();
// Read responder handshake message
assertFalse(embeddedChannel.outboundMessages().isEmpty());
final BinaryWebSocketFrame responderHandshakeFrame = (BinaryWebSocketFrame)
embeddedChannel.outboundMessages().poll();
final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] responderHandshakeBytes =
new byte[responderHandshakeFrame.content().readableBytes()];
responderHandshakeFrame.content().readBytes(responderHandshakeBytes);
new byte[responderHandshakeFrame.readableBytes()];
responderHandshakeFrame.readBytes(responderHandshakeBytes);
// ephemeral key, empty encrypted payload AEAD tag
final byte[] handshakeResponsePayload = new byte[32 + 16];

View File

@@ -15,10 +15,10 @@ import static org.mockito.Mockito.when;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.internal.EmptyArrays;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
@@ -34,6 +34,7 @@ import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@@ -204,13 +205,12 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)));
assertTrue(embeddedChannel.writeOneInbound(initiatorMessageFrame).await().isSuccess());
// While waiting for the public key, send another message
final ChannelFuture f = embeddedChannel.writeOneInbound(
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[0]))).await();
final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
assertInstanceOf(NoiseHandshakeException.class, f.exceptionNow());
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
@@ -267,8 +267,7 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
final HandshakeState clientHandshakeState = clientHandshakeState();
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
final BinaryWebSocketFrame initiatorMessageFrame = new BinaryWebSocketFrame(
Unpooled.wrappedBuffer(initiatorMessage));
final ByteBuf initiatorMessageFrame = Unpooled.wrappedBuffer(initiatorMessage);
final ChannelFuture await = embeddedChannel.writeOneInbound(initiatorMessageFrame).await();
assertEquals(0, initiatorMessageFrame.refCnt());
if (!await.isSuccess()) {
@@ -286,11 +285,10 @@ class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
assertFalse(embeddedChannel.outboundMessages().isEmpty());
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
final ByteBuf serverStaticKeyMessageFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
new byte[serverStaticKeyMessageFrame.readableBytes()];
serverStaticKeyMessageFrame.readBytes(serverStaticKeyMessageBytes);
assertEquals(readHandshakeResponse(clientHandshakeState, serverStaticKeyMessageBytes).length, 0);

View File

@@ -1,55 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import java.util.Optional;
class NoiseClientHandshakeHandler extends ChannelInboundHandlerAdapter {
private final NoiseClientHandshakeHelper handshakeHelper;
private final byte[] payload;
NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper, final byte[] payload) {
this.handshakeHelper = handshakeHelper;
this.payload = payload;
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
byte[] handshakeMessage = handshakeHelper.write(payload);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
}
}
super.userEventTriggered(context, event);
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message)
throws NoiseHandshakeException {
if (message instanceof BinaryWebSocketFrame frame) {
try {
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame.content()));
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
handshakeHelper.destroy();
}
}

View File

@@ -16,6 +16,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientHandshakeHelper;
public class NoiseHandshakeHelperTest {

View File

@@ -1,160 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.Supplier;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
class NoiseWebSocketTunnelClient implements AutoCloseable {
private final ServerBootstrap serverBootstrap;
private Channel serverChannel;
static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
static class Builder {
final SocketAddress remoteServerAddress;
NioEventLoopGroup eventLoopGroup;
ECPublicKey serverPublicKey;
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
HttpHeaders headers = new DefaultHttpHeaders();
WebSocketCloseListener webSocketCloseListener = WebSocketCloseListener.NOOP_LISTENER;
boolean authenticated = false;
ECKeyPair ecKeyPair = null;
UUID accountIdentifier = null;
byte deviceId = 0x00;
boolean useTls;
X509Certificate trustedServerCertificate = null;
Supplier<HAProxyMessage> proxyMessageSupplier = null;
Builder(
final SocketAddress remoteServerAddress,
final NioEventLoopGroup eventLoopGroup,
final ECPublicKey serverPublicKey) {
this.remoteServerAddress = remoteServerAddress;
this.eventLoopGroup = eventLoopGroup;
this.serverPublicKey = serverPublicKey;
}
Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
this.authenticated = true;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.ecKeyPair = ecKeyPair;
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
return this;
}
Builder setWebsocketUri(final URI websocketUri) {
this.websocketUri = websocketUri;
return this;
}
Builder setUseTls(X509Certificate trustedServerCertificate) {
this.useTls = true;
this.trustedServerCertificate = trustedServerCertificate;
return this;
}
Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
this.proxyMessageSupplier = proxyMessageSupplier;
return this;
}
Builder setHeaders(final HttpHeaders headers) {
this.headers = headers;
return this;
}
Builder setWebSocketCloseListener(final WebSocketCloseListener webSocketCloseListener) {
this.webSocketCloseListener = webSocketCloseListener;
return this;
}
Builder setServerPublicKey(ECPublicKey serverPublicKey) {
this.serverPublicKey = serverPublicKey;
return this;
}
NoiseWebSocketTunnelClient build() {
final NoiseWebSocketTunnelClient client =
new NoiseWebSocketTunnelClient(eventLoopGroup, fastOpenRequest -> new EstablishRemoteConnectionHandler(
useTls, trustedServerCertificate, websocketUri, authenticated, ecKeyPair, serverPublicKey,
accountIdentifier, deviceId, headers, remoteServerAddress, webSocketCloseListener, proxyMessageSupplier,
fastOpenRequest));
client.start();
return client;
}
}
private NoiseWebSocketTunnelClient(NioEventLoopGroup eventLoopGroup,
Function<byte[], EstablishRemoteConnectionHandler> handler) {
this.serverBootstrap = new ServerBootstrap()
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
.channel(LocalServerChannel.class)
.group(eventLoopGroup)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline()
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
// in the handshake.
.addLast(Http2Buffering.handler())
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
// connect to the remote service
.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
requestBufferedEvent.fastOpenRequest().release();
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
}
super.userEventTriggered(ctx, evt);
}
})
.addLast(new ClientErrorHandler());
}
});
}
LocalAddress getLocalAddress() {
return (LocalAddress) serverChannel.localAddress();
}
private NoiseWebSocketTunnelClient start() {
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
return this;
}
@Override
public void close() throws InterruptedException {
serverChannel.close().await();
}
}

View File

@@ -1,705 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor;
private static ExecutorService serverCallExecutor;
private static X509Certificate serverTlsCertificate;
private GrpcClientConnectionManager grpcClientConnectionManager;
private ClientPublicKeysManager clientPublicKeysManager;
private ECKeyPair serverKeyPair;
private ECKeyPair clientKeyPair;
private ManagedLocalGrpcServer authenticatedGrpcServer;
private ManagedLocalGrpcServer anonymousGrpcServer;
private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer;
private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer;
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
// They were generated with:
//
// ```shell
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
// ```
private static final String SERVER_CERTIFICATE = """
-----BEGIN CERTIFICATE-----
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
iOr9sHiO8Rn2u0xRKgU5Ig==
-----END CERTIFICATE-----
""";
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
private static final String SERVER_PRIVATE_KEY = """
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
""";
@BeforeAll
static void setUpBeforeAll() throws CertificateException {
nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
}
@BeforeEach
void setUp() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException, InterruptedException {
final PrivateKey serverTlsPrivateKey;
{
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
serverTlsPrivateKey =
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
}
clientKeyPair = Curve.generateKeyPair();
serverKeyPair = Curve.generateKeyPair();
grpcClientConnectionManager = new GrpcClientConnectionManager();
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.addService(new EchoServiceImpl())
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
}
};
authenticatedGrpcServer.start();
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
}
};
anonymousGrpcServer.start();
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
new X509Certificate[]{serverTlsCertificate},
serverTlsPrivateKey,
nioEventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
tlsNoiseWebSocketTunnelServer.start();
plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
null,
null,
nioEventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
plaintextNoiseWebSocketTunnelServer.start();
}
@AfterEach
void tearDown() throws InterruptedException {
tlsNoiseWebSocketTunnelServer.stop();
plaintextNoiseWebSocketTunnelServer.stop();
authenticatedGrpcServer.stop();
anonymousGrpcServer.stop();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
serverCallExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = authenticated()
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
}
@Test
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
}
@Test
void connectAuthenticatedToAnonymousService() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebsocketUri(NoiseWebSocketTunnelClient.ANONYMOUS_WEBSOCKET_URI)
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAnonymous() throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = anonymous().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAnonymousBadServerKeySignature() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseWebSocketTunnelClient client = anonymous()
.setWebSocketCloseListener(webSocketCloseListener)
.setServerPublicKey(Curve.generateKeyPair().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAnonymousToAuthenticatedService() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final NoiseWebSocketTunnelClient client = anonymous()
.setWebsocketUri(NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(
ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
private ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
return NettyChannelBuilder.forAddress(localAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(defaultEventLoopGroup)
.usePlaintext()
.build();
}
@Test
void rejectIllegalRequests() throws Exception {
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
final TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(keyStore);
final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
final URI authenticatedUri =
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
final URI incorrectUri =
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
.uri(authenticatedUri)
.PUT(HttpRequest.BodyPublishers.ofString("test"))
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"Non-GET requests should not be allowed");
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(authenticatedUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests without upgrade headers should not be allowed");
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(incorrectUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests to unrecognized URIs should not be allowed");
}
}
@Test
void getRequestAttributes() throws InterruptedException {
final String remoteAddress = "4.5.6.7";
final String acceptLanguage = "en";
final String userAgent = "Signal-Desktop/1.2.3 Linux";
final HttpHeaders headers = new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add("X-Forwarded-For", remoteAddress)
.add("Accept-Language", acceptLanguage)
.add("User-Agent", userAgent);
try (final NoiseWebSocketTunnelClient client = anonymous().setHeaders(headers).build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
} finally {
channel.shutdown();
}
}
}
@Test
void closeForReauthentication() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertTrue(connectionCloseLatch.await(2, TimeUnit.SECONDS));
assertEquals(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED.getStatusCode(),
serverCloseStatusCode.get());
assertTrue(closedByServer.get());
} finally {
channel.shutdown();
}
}
}
@Test
void waitForCallCompletion() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final NoiseWebSocketTunnelClient client = authenticated()
.setWebSocketCloseListener(webSocketCloseListener)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
// Start an open-ended server call and leave it in a non-complete state
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
new StreamObserver<>() {
@Override
public void onNext(final EchoResponse echoResponse) {
responseCountDownLatch.countDown();
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onCompleted() {
}
});
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
// truly started before requesting connection closure.
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertFalse(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should not close until active requests have finished");
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
// Complete the open-ended server call
echoRequestStreamObserver.onCompleted();
assertTrue(connectionCloseLatch.await(1, TimeUnit.SECONDS),
"Channel should close once active requests have finished");
assertTrue(closedByServer.get());
assertEquals(4004, serverCloseStatusCode.get());
} finally {
channel.shutdown();
}
}
}
private NoiseWebSocketTunnelClient.Builder anonymous() {
return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
.setUseTls(serverTlsCertificate);
}
private NoiseWebSocketTunnelClient.Builder authenticated() {
return new NoiseWebSocketTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), nioEventLoopGroup, serverKeyPair.getPublicKey())
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID)
.setUseTls(serverTlsCertificate);
}
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
return includeProxyMesage
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443)
: null;
}
}

View File

@@ -1,24 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
class OutboundCloseWebSocketFrameHandler extends ChannelOutboundHandlerAdapter {
private final WebSocketCloseListener webSocketCloseListener;
OutboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
this.webSocketCloseListener = webSocketCloseListener;
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
webSocketCloseListener.handleWebSocketClosedByClient(closeWebSocketFrame.statusCode());
}
super.write(context, message, promise);
}
}

View File

@@ -1,80 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A TypedNoiseChannelDuplexHandler is a convenience {@link ChannelDuplexHandler} that can be inserted in a pipeline
* after a successful websocket handshake. It expects inbound messages to be {@link BinaryWebSocketFrame}s and outbound
* messages to be bytes.
*/
abstract class TypedNoiseChannelDuplexHandler extends ChannelDuplexHandler {
private static final Logger log = LoggerFactory.getLogger(TypedNoiseChannelDuplexHandler.class);
/**
* Handle an inbound message. The frame will be automatically released after the method is finished running.
*
* @param context The current {@link ChannelHandlerContext}
* @param frameBytes A {@link ByteBuf} extracted from a {@link BinaryWebSocketFrame} that contains a complete noise
* packet
* @throws Exception
*/
abstract void handleInbound(final ChannelHandlerContext context, ByteBuf frameBytes) throws Exception;
/**
* Handle an outbound byte message. The message will be automatically released after the method is finished running.
*
* @param context The current {@link ChannelHandlerContext}
* @param bytes The bytes to write
* @throws Exception
*/
abstract void handleOutbound(final ChannelHandlerContext context, final ByteBuf bytes,
final ChannelPromise promise) throws Exception;
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof BinaryWebSocketFrame frame) {
handleInbound(context, frame.content());
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} finally {
ReferenceCountUtil.release(message);
}
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
if (message instanceof ByteBuf serverResponse) {
try {
handleOutbound(context, serverResponse, promise);
} finally {
ReferenceCountUtil.release(serverResponse);
}
} else {
if (!(message instanceof WebSocketFrame)) {
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
// get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);
}
}
}

View File

@@ -1,18 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
interface WebSocketCloseListener {
WebSocketCloseListener NOOP_LISTENER = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
}
};
void handleWebSocketClosedByClient(int statusCode);
void handleWebSocketClosedByServer(int statusCode);
}

View File

@@ -0,0 +1,16 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ClientErrorHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(ClientErrorHandler.class);
@Override
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
log.error("Caught inbound error in client; closing connection", cause);
context.channel().close();
}
}

View File

@@ -0,0 +1,53 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
public enum CloseReason {
SERVER_CLOSED,
NOISE_ERROR,
NOISE_HANDSHAKE_ERROR,
AUTHENTICATION_ERROR,
INTERNAL_SERVER_ERROR,
UNKNOWN
}
public enum CloseInitiator {
SERVER,
CLIENT
}
public static CloseFrameEvent fromWebsocketCloseFrame(
CloseWebSocketFrame closeWebSocketFrame,
CloseInitiator closeInitiator) {
final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
case 4003 -> CloseReason.NOISE_ERROR;
case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
case 4002 -> CloseReason.AUTHENTICATION_ERROR;
case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
case 1012 -> CloseReason.SERVER_CLOSED;
default -> CloseReason.UNKNOWN;
};
return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
}
public static CloseFrameEvent fromNoiseDirectErrorFrame(
NoiseDirectProtos.Error noiseDirectError,
CloseInitiator closeInitiator) {
final CloseReason code = switch (noiseDirectError.getType()) {
case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
case AUTHENTICATION_ERROR -> CloseReason.AUTHENTICATION_ERROR;
case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
};
return new CloseFrameEvent(code, closeInitiator, noiseDirectError.getMessage());
}
}

View File

@@ -0,0 +1,136 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
/**
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
* gRPC server.
* <p>
* This handler waits until the first gRPC client message is ready and then establishes a connection with the remote
* gRPC server. It expects the provided remoteHandlerStack to emit a {@link ReadyForNoiseHandshakeEvent} when the remote
* connection is ready for its first inbound payload, and to emit a {@link NoiseClientHandshakeCompleteEvent} when the
* handshake is finished.
*/
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final List<ChannelHandler> remoteHandlerStack;
@Nullable
private final AuthenticatedDevice authenticatedDevice;
private final SocketAddress remoteServerAddress;
// If provided, will be sent with the payload in the noise handshake
private final byte[] fastOpenRequest;
private final List<Object> pendingReads = new ArrayList<>();
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
EstablishRemoteConnectionHandler(
final List<ChannelHandler> remoteHandlerStack,
@Nullable final AuthenticatedDevice authenticatedDevice,
final SocketAddress remoteServerAddress,
@Nullable byte[] fastOpenRequest) {
this.remoteHandlerStack = remoteHandlerStack;
this.authenticatedDevice = authenticatedDevice;
this.remoteServerAddress = remoteServerAddress;
this.fastOpenRequest = fastOpenRequest == null ? new byte[0] : fastOpenRequest;
}
@Override
public void handlerAdded(final ChannelHandlerContext localContext) {
new Bootstrap()
.channel(NioSocketChannel.class)
.group(localContext.channel().eventLoop())
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel channel) throws Exception {
for (ChannelHandler handler : remoteHandlerStack) {
channel.pipeline().addLast(handler);
}
channel.pipeline()
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
throws Exception {
switch (event) {
case ReadyForNoiseHandshakeEvent ignored ->
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(initialPayload()))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
case NoiseClientHandshakeCompleteEvent(Optional<byte[]> fastResponse) -> {
remoteContext.pipeline()
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
// If there was a payload response on the handshake, write it back to our gRPC client
fastResponse.ifPresent(plaintext ->
localContext.writeAndFlush(Unpooled.wrappedBuffer(plaintext)));
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
pendingReads.forEach(localContext::fireChannelRead);
pendingReads.clear();
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
}
default -> {
}
}
super.userEventTriggered(remoteContext, event);
}
})
.addLast(new ClientErrorHandler());
}
})
.connect(remoteServerAddress)
.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
// Close the local connection if the remote channel closes and vice versa
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
} else {
localContext.close();
}
});
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
pendingReads.add(message);
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
pendingReads.forEach(ReferenceCountUtil::release);
pendingReads.clear();
}
private byte[] initialPayload() {
if (authenticatedDevice == null) {
return fastOpenRequest;
}
final ByteBuffer bb = ByteBuffer.allocate(17 + fastOpenRequest.length);
bb.putLong(authenticatedDevice.accountIdentifier().getMostSignificantBits());
bb.putLong(authenticatedDevice.accountIdentifier().getLeastSignificantBits());
bb.put(authenticatedDevice.deviceId());
bb.put(fastOpenRequest);
bb.flip();
return bb.array();
}
}

View File

@@ -0,0 +1,9 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.buffer.ByteBuf;
public record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

View File

@@ -2,7 +2,7 @@
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
@@ -12,6 +12,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HexFormat;
@@ -27,12 +28,12 @@ import java.util.stream.Stream;
* Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a
* {@link FastOpenRequestBufferedEvent}
*/
class Http2Buffering {
public class Http2Buffering {
/**
* Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request
*/
static ChannelInboundHandler handler() {
public static ChannelInboundHandler handler() {
return new Http2PrefaceHandler();
}

View File

@@ -2,7 +2,7 @@
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.client;
import java.util.Optional;

View File

@@ -0,0 +1,56 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
import java.util.Optional;
public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
private final NoiseClientHandshakeHelper handshakeHelper;
public NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper) {
this.handshakeHelper = handshakeHelper;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof ByteBuf plaintextHandshakePayload) {
final byte[] payloadBytes = ByteBufUtil.getBytes(plaintextHandshakePayload,
plaintextHandshakePayload.readerIndex(), plaintextHandshakePayload.readableBytes(),
false);
final byte[] handshakeMessage = handshakeHelper.write(payloadBytes);
ctx.write(Unpooled.wrappedBuffer(handshakeMessage), promise);
} else {
ctx.write(msg, promise);
}
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message)
throws NoiseHandshakeException {
if (message instanceof ByteBuf frame) {
try {
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame));
final Optional<byte[]> fastResponse = Optional.ofNullable(payload.length == 0 ? null : payload);
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(fastResponse));
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
handshakeHelper.destroy();
}
}

View File

@@ -2,7 +2,7 @@
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.client;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
@@ -11,6 +11,8 @@ import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
public class NoiseClientHandshakeHelper {
@@ -22,7 +24,7 @@ public class NoiseClientHandshakeHelper {
this.handshakeState = handshakeState;
}
static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
public static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
try {
final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0);
@@ -34,7 +36,7 @@ public class NoiseClientHandshakeHelper {
}
}
static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
public static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
try {
final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
@@ -45,7 +47,7 @@ public class NoiseClientHandshakeHelper {
}
}
byte[] write(final byte[] requestPayload) throws ShortBufferException {
public byte[] write(final byte[] requestPayload) throws ShortBufferException {
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16];
handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length);
return initiateHandshakeMessage;
@@ -60,7 +62,7 @@ public class NoiseClientHandshakeHelper {
};
}
byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
public byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
// Don't process additional messages if the handshake failed and we're just waiting to close
if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) {
throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction());
@@ -83,11 +85,11 @@ public class NoiseClientHandshakeHelper {
}
}
CipherStatePair split() {
public CipherStatePair split() {
return this.handshakeState.split();
}
void destroy() {
public void destroy() {
this.handshakeState.destroy();
}
}

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.client;
import com.southernstorm.noise.protocol.CipherState;
import com.southernstorm.noise.protocol.CipherStatePair;
@@ -8,8 +8,6 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -17,7 +15,7 @@ import org.slf4j.LoggerFactory;
/**
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
*/
class NoiseClientTransportHandler extends ChannelDuplexHandler {
public class NoiseClientTransportHandler extends ChannelDuplexHandler {
private final CipherStatePair cipherStatePair;
@@ -30,19 +28,19 @@ class NoiseClientTransportHandler extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof BinaryWebSocketFrame frame) {
if (message instanceof ByteBuf frame) {
final CipherState cipherState = cipherStatePair.getReceiver();
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer.
final byte[] noiseBuffer = ByteBufUtil.getBytes(frame.content());
final byte[] noiseBuffer = ByteBufUtil.getBytes(frame);
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// Anything except binary frames should have been filtered out of the pipeline by now; treat this as an
// error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
@@ -69,16 +67,13 @@ class NoiseClientTransportHandler extends ChannelDuplexHandler {
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
context.write(Unpooled.wrappedBuffer(noiseBuffer), promise);
} finally {
ReferenceCountUtil.release(plaintext);
}
} else {
if (!(message instanceof WebSocketFrame)) {
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
// get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
// Clients only write ByteBufs or close the connection on errors, so any other message is unexpected
log.warn("Unexpected object in pipeline: {}", message);
context.write(message, promise);
}
}

View File

@@ -0,0 +1,354 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import com.southernstorm.noise.protocol.Noise;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.*;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.MessageToMessageCodec;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Supplier;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.ReferenceCountUtil;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec;
import javax.net.ssl.SSLException;
public class NoiseTunnelClient implements AutoCloseable {
private final CompletableFuture<CloseFrameEvent> closeEventFuture;
private final ServerBootstrap serverBootstrap;
private Channel serverChannel;
public static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
public static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
public enum FramingType {
WEBSOCKET,
NOISE_DIRECT
}
public static class Builder {
final SocketAddress remoteServerAddress;
NioEventLoopGroup eventLoopGroup;
ECPublicKey serverPublicKey;
FramingType framingType = FramingType.WEBSOCKET;
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
HttpHeaders headers = new DefaultHttpHeaders();
boolean authenticated = false;
ECKeyPair ecKeyPair = null;
UUID accountIdentifier = null;
byte deviceId = 0x00;
boolean useTls;
X509Certificate trustedServerCertificate = null;
Supplier<HAProxyMessage> proxyMessageSupplier = null;
public Builder(
final SocketAddress remoteServerAddress,
final NioEventLoopGroup eventLoopGroup,
final ECPublicKey serverPublicKey) {
this.remoteServerAddress = remoteServerAddress;
this.eventLoopGroup = eventLoopGroup;
this.serverPublicKey = serverPublicKey;
}
public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
this.authenticated = true;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.ecKeyPair = ecKeyPair;
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
return this;
}
public Builder setWebsocketUri(final URI websocketUri) {
this.websocketUri = websocketUri;
return this;
}
public Builder setUseTls(X509Certificate trustedServerCertificate) {
this.useTls = true;
this.trustedServerCertificate = trustedServerCertificate;
return this;
}
public Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
this.proxyMessageSupplier = proxyMessageSupplier;
return this;
}
public Builder setHeaders(final HttpHeaders headers) {
this.headers = headers;
return this;
}
public Builder setServerPublicKey(ECPublicKey serverPublicKey) {
this.serverPublicKey = serverPublicKey;
return this;
}
public Builder setFramingType(FramingType framingType) {
this.framingType = framingType;
return this;
}
public NoiseTunnelClient build() {
final List<ChannelHandler> handlers = new ArrayList<>();
if (proxyMessageSupplier != null) {
handlers.addAll(List.of(HAProxyMessageEncoder.INSTANCE, new HAProxyMessageSender(proxyMessageSupplier)));
}
if (useTls) {
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (trustedServerCertificate != null) {
sslContextBuilder.trustManager(trustedServerCertificate);
}
try {
handlers.add(sslContextBuilder.build().newHandler(ByteBufAllocator.DEFAULT));
} catch (SSLException e) {
throw new IllegalArgumentException(e);
}
}
// handles the wrapping and unrwrapping the framing layer (websockets or noisedirect)
handlers.addAll(switch (framingType) {
case WEBSOCKET -> websocketHandlerStack(websocketUri, headers);
case NOISE_DIRECT -> noiseDirectHandlerStack(authenticated);
});
final NoiseClientHandshakeHelper helper = authenticated
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
: NoiseClientHandshakeHelper.NK(serverPublicKey);
handlers.add(new NoiseClientHandshakeHandler(helper));
// Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off
// information about why the connection was closed.
final UserEventFuture<CloseFrameEvent> closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class);
handlers.add(closeEventHandler);
final NoiseTunnelClient client =
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, fastOpenRequest -> new EstablishRemoteConnectionHandler(
handlers,
authenticated ? new AuthenticatedDevice(accountIdentifier, deviceId) : null,
remoteServerAddress,
fastOpenRequest));
client.start();
return client;
}
}
private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup,
CompletableFuture<CloseFrameEvent> closeEventFuture,
Function<byte[], EstablishRemoteConnectionHandler> handler) {
this.closeEventFuture = closeEventFuture;
this.serverBootstrap = new ServerBootstrap()
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
.channel(LocalServerChannel.class)
.group(eventLoopGroup)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline()
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
// in the handshake.
.addLast(Http2Buffering.handler())
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
// connect to the remote service
.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
if (evt instanceof FastOpenRequestBufferedEvent requestBufferedEvent) {
byte[] fastOpenRequest = ByteBufUtil.getBytes(requestBufferedEvent.fastOpenRequest());
requestBufferedEvent.fastOpenRequest().release();
ctx.pipeline().addLast(handler.apply(fastOpenRequest));
}
super.userEventTriggered(ctx, evt);
}
})
.addLast(new ClientErrorHandler());
}
});
}
private static class UserEventFuture<T> extends ChannelInboundHandlerAdapter {
private final CompletableFuture<T> future = new CompletableFuture<>();
private final Class<T> cls;
UserEventFuture(Class<T> cls) {
this.cls = cls;
}
@Override
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
if (cls.isInstance(evt)) {
future.complete((T) evt);
}
ctx.fireUserEventTriggered(evt);
}
}
public LocalAddress getLocalAddress() {
return (LocalAddress) serverChannel.localAddress();
}
private NoiseTunnelClient start() {
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
return this;
}
@Override
public void close() throws InterruptedException {
serverChannel.close().await();
}
/**
* @return A future that completes when a close frame is observed
*/
public CompletableFuture<CloseFrameEvent> closeFrameFuture() {
return closeEventFuture;
}
private static List<ChannelHandler> noiseDirectHandlerStack(boolean authenticated) {
return List.of(
new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2),
new NoiseDirectFrameCodec(),
new ChannelDuplexHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ctx.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
ctx.fireChannelActive();
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
try {
final NoiseDirectProtos.Error errorPayload =
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
ctx.fireUserEventTriggered(
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.SERVER));
} finally {
ReferenceCountUtil.release(msg);
}
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.ERROR) {
final NoiseDirectProtos.Error errorPayload =
NoiseDirectProtos.Error.parseFrom(ByteBufUtil.getBytes(ndf.content()));
ctx.fireUserEventTriggered(
CloseFrameEvent.fromNoiseDirectErrorFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
}
ctx.write(msg, promise);
}
},
new MessageToMessageCodec<NoiseDirectFrame, ByteBuf>() {
boolean noiseHandshakeFinished = false;
@Override
protected void encode(final ChannelHandlerContext ctx, final ByteBuf msg, final List<Object> out) {
final NoiseDirectFrame.FrameType frameType = noiseHandshakeFinished
? NoiseDirectFrame.FrameType.DATA
: (authenticated ? NoiseDirectFrame.FrameType.IK_HANDSHAKE : NoiseDirectFrame.FrameType.NK_HANDSHAKE);
noiseHandshakeFinished = true;
out.add(new NoiseDirectFrame(frameType, msg.retain()));
}
@Override
protected void decode(final ChannelHandlerContext ctx, final NoiseDirectFrame msg,
final List<Object> out) {
out.add(msg.content().retain());
}
});
}
private static List<ChannelHandler> websocketHandlerStack(final URI websocketUri, final HttpHeaders headers) {
return List.of(
new HttpClientCodec(),
new HttpObjectAggregator(Noise.MAX_PACKET_LEN),
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
// want to react to them on our own, we need to catch them before they hit that handler.
new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
context.fireUserEventTriggered(
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.SERVER));
}
super.channelRead(context, message);
}
},
new WebSocketClientProtocolHandler(websocketUri,
WebSocketVersion.V13,
null,
false,
headers,
Noise.MAX_PACKET_LEN,
10_000),
new ChannelOutboundHandlerAdapter() {
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
context.fireUserEventTriggered(
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.CLIENT));
}
super.write(context, message, promise);
}
},
new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
context.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
}
}
context.fireUserEventTriggered(event);
}
},
new WebsocketPayloadCodec());
}
}

View File

@@ -0,0 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
public record ReadyForNoiseHandshakeEvent() {
}

View File

@@ -0,0 +1,49 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import java.util.concurrent.Executor;
class DirectNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
private NoiseDirectTunnelServer noiseDirectTunnelServer;
@Override
protected void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception {
noiseDirectTunnelServer = new NoiseDirectTunnelServer(0,
eventLoopGroup,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress);
noiseDirectTunnelServer.start();
}
@Override
protected void stop() throws InterruptedException {
noiseDirectTunnelServer.stop();
}
@Override
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
return new NoiseTunnelClient
.Builder(noiseDirectTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
.setFramingType(NoiseTunnelClient.FramingType.NOISE_DIRECT);
}
}

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -18,6 +18,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {

View File

@@ -0,0 +1,237 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.io.ByteArrayInputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeoutException;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer;
private X509Certificate serverTlsCertificate;
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
// They were generated with:
//
// ```shell
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
// ```
private static final String SERVER_CERTIFICATE = """
-----BEGIN CERTIFICATE-----
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
iOr9sHiO8Rn2u0xRKgU5Ig==
-----END CERTIFICATE-----
""";
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
private static final String SERVER_PRIVATE_KEY = """
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
""";
@Override
protected void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception {
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
final PrivateKey serverTlsPrivateKey;
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
serverTlsPrivateKey =
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
new X509Certificate[]{serverTlsCertificate},
serverTlsPrivateKey,
eventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
recognizedProxySecret);
tlsNoiseWebSocketTunnelServer.start();
}
@Override
protected void stop() throws InterruptedException {
tlsNoiseWebSocketTunnelServer.stop();
}
@Override
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup,
final ECPublicKey serverPublicKey) {
return new NoiseTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
.setUseTls(serverTlsCertificate);
}
@Test
void getRequestAttributes() throws InterruptedException {
final String remoteAddress = "4.5.6.7";
final String acceptLanguage = "en";
final String userAgent = "Signal-Desktop/1.2.3 Linux";
final HttpHeaders headers = new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add("X-Forwarded-For", remoteAddress)
.add("Accept-Language", acceptLanguage)
.add("User-Agent", userAgent);
try (final NoiseTunnelClient client = anonymous().setHeaders(headers).build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAuthenticatedToAnonymousService() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = authenticated()
.setWebsocketUri(NoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void connectAnonymousToAuthenticatedService() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = anonymous()
.setWebsocketUri(NoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void rejectIllegalRequests() throws Exception {
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
final TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(keyStore);
final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
final URI authenticatedUri =
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated",
null, null);
final URI incorrectUri =
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect",
null, null);
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
.uri(authenticatedUri)
.PUT(HttpRequest.BodyPublishers.ofString("test"))
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"Non-GET requests should not be allowed");
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(authenticatedUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests without upgrade headers should not be allowed");
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(incorrectUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests to unrecognized URIs should not be allowed");
}
}
}

View File

@@ -0,0 +1,52 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import java.util.concurrent.Executor;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer;
@Override
protected void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception {
plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
null,
null,
eventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
recognizedProxySecret);
plaintextNoiseWebSocketTunnelServer.start();
}
@Override
protected void stop() throws InterruptedException {
plaintextNoiseWebSocketTunnelServer.stop();
}
@Override
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
return new NoiseTunnelClient
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey);
}
}

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
@@ -19,6 +19,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -34,6 +34,10 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {