Add NoiseDirect framing protocol

This commit is contained in:
ravi-signal
2025-04-30 15:05:05 -05:00
committed by GitHub
parent e285bf1a52
commit 0398e02690
63 changed files with 2111 additions and 1450 deletions

View File

@@ -154,7 +154,7 @@ import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.NoiseWebSocketTunnelServer;
import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer;
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;

View File

@@ -1,7 +1,9 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
/**
* Indicates that an attempt to authenticate a remote client failed for some reason.
*/
class ClientAuthenticationException extends Exception {
public class ClientAuthenticationException extends NoStackTraceException {
}

View File

@@ -1,62 +1,46 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import javax.crypto.BadPaddingException;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. If the client has completed a
* WebSocket handshake, the error handler will send appropriate WebSocket closure codes to the client in an attempt to
* identify the problem. If the client has not completed a WebSocket handshake, the handler simply closes the
* connection.
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions
* thrown in inbound handlers into {@link OutboundCloseErrorMessage}s.
*/
class ErrorHandler extends ChannelInboundHandlerAdapter {
private boolean websocketHandshakeComplete = false;
public class ErrorHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
setWebsocketHandshakeComplete();
}
context.fireUserEventTriggered(event);
}
protected void setWebsocketHandshakeComplete() {
this.websocketHandshakeComplete = true;
}
private static OutboundCloseErrorMessage UNAUTHENTICATED_CLOSE = new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.AUTHENTICATION_ERROR,
"Not authenticated");
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.NOISE_ERROR,
"Noise encryption error");
@Override
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
if (websocketHandshakeComplete) {
final WebSocketCloseStatus webSocketCloseStatus = switch (ExceptionUtils.unwrap(cause)) {
case NoiseHandshakeException e -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.toWebSocketCloseStatus(e.getMessage());
case ClientAuthenticationException ignored -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.toWebSocketCloseStatus("Not authenticated");
case BadPaddingException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
case NoiseException ignored -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.toWebSocketCloseStatus("Noise encryption error");
final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) {
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
e.getMessage());
case ClientAuthenticationException ignored -> UNAUTHENTICATED_CLOSE;
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
default -> {
log.warn("An unexpected exception reached the end of the pipeline", cause);
yield WebSocketCloseStatus.INTERNAL_SERVER_ERROR;
yield new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR,
cause.getMessage());
}
};
context.writeAndFlush(new CloseWebSocketFrame(webSocketCloseStatus))
context.writeAndFlush(closeMessage)
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} else {
log.debug("Error occurred before websocket handshake complete", cause);
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
// way; just close the connection instead.
context.close();
}
}
}

View File

@@ -7,8 +7,6 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList;
import java.util.List;
@@ -22,7 +20,7 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
* server.
*/
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
private final GrpcClientConnectionManager grpcClientConnectionManager;
@@ -79,7 +77,9 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
// Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
remoteChannelContext.channel()
.write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
remoteChannelContext.pipeline()
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));

View File

@@ -7,7 +7,6 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.util.AttributeKey;
import java.net.InetAddress;
import java.util.ArrayList;
@@ -63,6 +62,9 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
private static OutboundCloseErrorMessage SERVER_CLOSED =
new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/**
@@ -161,9 +163,7 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
}
private static void closeRemoteChannel(final Channel channel) {
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
.toWebSocketCloseStatus("Reauthentication required")))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
@VisibleForTesting
@@ -198,16 +198,16 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
}
/**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* Handles receipt of a handshake message and associates attributes and headers from the handshake
* request with the channel via which the handshake took place.
*
* @param channel the channel that completed a WebSocket handshake
* @param channel the channel where the handshake was initiated
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null}
*/
static void handleHandshakeComplete(final Channel channel,
public static void handleHandshakeInitiated(final Channel channel,
final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) {
@@ -227,11 +227,10 @@ public class GrpcClientConnectionManager implements DisconnectionRequestListener
}
/**
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
* Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
*
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
* server
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
* @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
* @param remoteChannel the channel from the remote client to the Noise tunnel
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
*/
void handleConnectionEstablished(final LocalChannel localChannel,

View File

@@ -4,7 +4,7 @@
*/
package org.whispersystems.textsecuregcm.grpc.net;
enum HandshakePattern {
public enum HandshakePattern {
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");

View File

@@ -17,7 +17,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
* Once the handler receives the handshake initiator message, it will fire a {@link NoiseIdentityDeterminedEvent}
* indicating that initiator connected anonymously.
*/
class NoiseAnonymousHandler extends NoiseHandler {
public class NoiseAnonymousHandler extends NoiseHandler {
public NoiseAnonymousHandler(final ECKeyPair ecKeyPair) {
super(new NoiseHandshakeHelper(HandshakePattern.NK, ecKeyPair));

View File

@@ -32,11 +32,11 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
* <p>
* As soon as the handler authenticates the caller, it will fire a {@link NoiseIdentityDeterminedEvent}.
*/
class NoiseAuthenticatedHandler extends NoiseHandler {
public class NoiseAuthenticatedHandler extends NoiseHandler {
private final ClientPublicKeysManager clientPublicKeysManager;
NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
public NoiseAuthenticatedHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair) {
super(new NoiseHandshakeHelper(HandshakePattern.IK, ecKeyPair));
this.clientPublicKeysManager = clientPublicKeysManager;

View File

@@ -1,10 +1,12 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
/**
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
* format or a general encryption error).
*/
class NoiseException extends Exception {
public class NoiseException extends NoStackTraceException {
public NoiseException(final String message) {
super(message);
}

View File

@@ -26,13 +26,14 @@ import javax.crypto.ShortBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* A bidirectional {@link io.netty.channel.ChannelHandler} that establishes a noise session with an initiator, decrypts
* inbound messages, and encrypts outbound messages
*/
abstract class NoiseHandler extends ChannelDuplexHandler {
public abstract class NoiseHandler extends ChannelDuplexHandler {
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
@@ -82,17 +83,16 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof BinaryWebSocketFrame frame) {
if (frame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
final String error = "Invalid noise message length " + frame.content().readableBytes();
if (message instanceof ByteBuf frame) {
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
final String error = "Invalid noise message length " + frame.readableBytes();
throw state == State.HANDSHAKE ? new NoiseHandshakeException(error) : new NoiseException(error);
}
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer.
handleInboundMessage(context, ByteBufUtil.getBytes(frame.content()));
handleInboundMessage(context, ByteBufUtil.getBytes(frame));
} else {
// Anything except binary WebSocket frames should have been filtered out of the pipeline by now; treat this as an
// error
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} catch (Exception e) {
@@ -122,7 +122,7 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
// Now that we've authenticated, write the handshake response
byte[] handshakeMessage = handshakeHelper.write(EmptyArrays.EMPTY_BYTES);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(handshakeMessage)))
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
@@ -193,16 +193,16 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
}
pc.finish(promise);
} finally {
ReferenceCountUtil.release(byteBuf);
}
} else {
if (!(message instanceof WebSocketFrame)) {
// Downstream handlers may write WebSocket frames that don't need to be encrypted (e.g. "close" frames that
// get issued in response to exceptions)
if (!(message instanceof OutboundCloseErrorMessage)) {
// Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
// that get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);

View File

@@ -1,10 +1,12 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
/**
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
* a general encryption error).
*/
class NoiseHandshakeException extends Exception {
public class NoiseHandshakeException extends NoStackTraceException {
public NoiseHandshakeException(final String message) {
super(message);

View File

@@ -10,4 +10,4 @@ import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
* type that performs authentication
*/
record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}
public record NoiseIdentityDeterminedEvent(Optional<AuthenticatedDevice> authenticatedDevice) {}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
/**
* An error written to the outbound pipeline that indicates the connection should be closed
*/
public record OutboundCloseErrorMessage(Code code, String message) {
public enum Code {
/**
* The server decided to close the connection. This could be because the server is going away, or it could be
* because the credentials for the connected client have been updated.
*/
SERVER_CLOSED,
/**
* There was a noise decryption error after the noise session was established
*/
NOISE_ERROR,
/**
* There was an error establishing the noise handshake
*/
NOISE_HANDSHAKE_ERROR,
/**
* The provided credentials were not valid
*/
AUTHENTICATION_ERROR,
INTERNAL_SERVER_ERROR
}
}

View File

@@ -8,7 +8,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
/**
* A proxy handler writes all data read from one channel to another peer channel.
*/
class ProxyHandler extends ChannelInboundHandlerAdapter {
public class ProxyHandler extends ChannelInboundHandlerAdapter {
private final Channel peerChannel;

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
/**
* In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
* noise packet to the noise layer as a {@link ByteBuf} for decryption.
* <p>
* In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
* so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
* handshake response, and then the subsequent messages are all data frame payloads.
*/
public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame frame) {
if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
ReferenceCountUtil.release(msg);
throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
}
ctx.fireChannelRead(frame.content());
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof ByteBuf bb) {
ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
} else {
ctx.write(msg, promise);
}
}
}

View File

@@ -0,0 +1,71 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
public class NoiseDirectFrame extends DefaultByteBufHolder {
static final byte VERSION = 0x00;
private final FrameType frameType;
public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
super(data);
this.frameType = frameType;
}
public FrameType frameType() {
return frameType;
}
public byte versionedFrameTypeByte() {
final byte frameBits = frameType().getFrameBits();
return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
}
public enum FrameType {
/**
* The payload is the initiator message or the responder message for a Noise NK handshake. If established, the
* session will be unauthenticated.
*/
NK_HANDSHAKE((byte) 1),
/**
* The payload is the initiator message or the responder message for a Noise IK handshake. If established, the
* session will be authenticated.
*/
IK_HANDSHAKE((byte) 2),
/**
* The payload is an encrypted noise packet.
*/
DATA((byte) 3),
/**
* A framing layer error occurred. The payload carries error details.
*/
ERROR((byte) 4);
private final byte frameType;
FrameType(byte frameType) {
if (frameType != (0x0F & frameType)) {
throw new IllegalStateException("Frame type must fit in 4 bits");
}
this.frameType = frameType;
}
public byte getFrameBits() {
return frameType;
}
public boolean isHandshake() {
return switch (this) {
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
case DATA, ERROR -> false;
};
}
}
}

View File

@@ -0,0 +1,90 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
/**
* Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
* have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
*/
public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof ByteBuf byteBuf) {
try {
ctx.fireChannelRead(deserialize(byteBuf));
} catch (Exception e) {
ReferenceCountUtil.release(byteBuf);
throw e;
}
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
try {
// Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
// network, nothing should have to make another copy of this. If later another layer is added, it may be more
// efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
ctx.writeAndFlush(serialized, promise);
} finally {
ReferenceCountUtil.release(noiseDirectFrame);
}
} else {
ctx.write(msg, promise);
}
}
private ByteBuf serialize(
final ChannelHandlerContext ctx,
final NoiseDirectFrame noiseDirectFrame) {
if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
}
// 1 version/frametype byte, 2 length bytes, content
final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
byteBuf.writeBytes(noiseDirectFrame.content());
return byteBuf;
}
private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
final byte versionAndFrameByte = byteBuf.readByte();
final int version = (versionAndFrameByte & 0xF0) >> 4;
if (version != NoiseDirectFrame.VERSION) {
throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
}
final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
case 3 -> NoiseDirectFrame.FrameType.DATA;
case 4 -> NoiseDirectFrame.FrameType.ERROR;
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
};
final int length = Short.toUnsignedInt(byteBuf.readShort());
if (length != byteBuf.readableBytes()) {
throw new IllegalArgumentException(
"Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
}
return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
}
}

View File

@@ -0,0 +1,75 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import java.io.IOException;
import java.net.InetSocketAddress;
/**
* Waits for a Handshake {@link NoiseDirectFrame} and then installs a {@link NoiseDirectDataFrameCodec} and
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} and removes itself
*/
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
public NoiseDirectHandshakeSelector(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame frame) {
try {
// We've received an inbound handshake frame so we know what kind of NoiseHandler we need (authenticated or
// anonymous). We construct it here, and then remember the handshake type so we can annotate our handshake
// response with the correct frame type whenever we receive it.
final ChannelDuplexHandler noiseHandler = switch (frame.frameType()) {
case DATA, ERROR ->
throw new NoiseHandshakeException("Invalid frame type for first message " + frame.frameType());
case IK_HANDSHAKE -> new NoiseAuthenticatedHandler(clientPublicKeysManager, ecKeyPair);
case NK_HANDSHAKE -> new NoiseAnonymousHandler(ecKeyPair);
};
if (ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
// TODO: Provide connection metadata / headers in handshake payload
GrpcClientConnectionManager.handleHandshakeInitiated(ctx.channel(),
inetSocketAddress.getAddress(),
"NoiseDirect",
"");
} else {
throw new IOException("Could not determine remote address");
}
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
// for network serialization. Note that we need to install the Data frame handler before firing the read,
// because we may receive an outbound message from the noiseHandler
ctx.pipeline().addAfter(ctx.name(), null, noiseHandler);
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
ctx.fireChannelRead(frame.content());
} catch (Exception e) {
ReferenceCountUtil.release(msg);
throw e;
}
} else {
ctx.fireChannelRead(msg);
}
}
}

View File

@@ -0,0 +1,39 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
/**
* Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
* written, the channel is closed
*/
class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof OutboundCloseErrorMessage err) {
final NoiseDirectProtos.Error.Type type = switch (err.code()) {
case SERVER_CLOSED -> NoiseDirectProtos.Error.Type.UNAVAILABLE;
case NOISE_ERROR -> NoiseDirectProtos.Error.Type.ENCRYPTION_ERROR;
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.Error.Type.HANDSHAKE_ERROR;
case AUTHENTICATION_ERROR -> NoiseDirectProtos.Error.Type.AUTHENTICATION_ERROR;
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.Error.Type.INTERNAL_ERROR;
};
final NoiseDirectProtos.Error proto = NoiseDirectProtos.Error.newBuilder()
.setType(type)
.setMessage(err.message())
.build();
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
proto.writeTo(new ByteBufOutputStream(byteBuf));
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.ERROR, byteBuf))
.addListener(ChannelFutureListener.CLOSE);
} else {
ctx.write(msg, promise);
}
}
}

View File

@@ -0,0 +1,90 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import com.google.common.annotations.VisibleForTesting;
import com.southernstorm.noise.protocol.Noise;
import io.dropwizard.lifecycle.Managed;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import java.net.InetSocketAddress;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
* binary framing protocol) and passes it through to a local gRPC server.
*/
public class NoiseDirectTunnelServer implements Managed {
private final ServerBootstrap bootstrap;
private ServerSocketChannel channel;
private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
public NoiseDirectTunnelServer(final int port,
final NioEventLoopGroup eventLoopGroup,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress) {
this.bootstrap = new ServerBootstrap()
.group(eventLoopGroup)
.channel(NioServerSocketChannel.class)
.localAddress(port)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel socketChannel) {
socketChannel.pipeline()
.addLast(new ProxyProtocolDetectionHandler())
.addLast(new HAProxyMessageHandler());
socketChannel.pipeline()
// frame byte followed by a 2-byte length field
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
// Parses NoiseDirectFrames from wire bytes and vice versa
.addLast(new NoiseDirectFrameCodec())
// Turn generic OutboundCloseErrorMessages into noise direct error frames
.addLast(new NoiseDirectOutboundErrorHandler())
// Waits for the handshake to finish and then replaces itself with a NoiseDirectFrameCodec and a
// NoiseHandler to handle noise encryption/decryption
.addLast(new NoiseDirectHandshakeSelector(clientPublicKeysManager, ecKeyPair))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(
grpcClientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
.addLast(new ErrorHandler());
}
});
}
@VisibleForTesting
public InetSocketAddress getLocalAddress() {
return channel.localAddress();
}
@Override
public void start() throws InterruptedException {
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
}
@Override
public void stop() throws InterruptedException {
if (channel != null) {
channel.close().await();
}
}
}

View File

@@ -1,12 +1,11 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
enum ApplicationWebSocketCloseReason {
NOISE_HANDSHAKE_ERROR(4001),
CLIENT_AUTHENTICATION_ERROR(4002),
NOISE_ENCRYPTION_ERROR(4003),
REAUTHENTICATION_REQUIRED(4004);
NOISE_ENCRYPTION_ERROR(4003);
private final int statusCode;
@@ -17,8 +16,4 @@ enum ApplicationWebSocketCloseReason {
public int getStatusCode() {
return statusCode;
}
WebSocketCloseStatus toWebSocketCloseStatus(final String reason) {
return new WebSocketCloseStatus(statusCode, reason);
}
}

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import com.google.common.annotations.VisibleForTesting;
import com.southernstorm.noise.protocol.Noise;
@@ -28,6 +28,7 @@ import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.*;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
@@ -103,7 +104,10 @@ public class NoiseWebSocketTunnelServer implements Managed {
// request and passed it down the pipeline
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
.addLast(new WebSocketServerProtocolHandler("/", true))
// Turn generic OutboundCloseErrorMessages into websocket close frames
.addLast(new WebSocketOutboundErrorHandler())
.addLast(new RejectUnsupportedMessagesHandler())
.addLast(new WebsocketPayloadCodec())
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
// a WebSocket handshake has been completed
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, recognizedProxySecret))

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;

View File

@@ -0,0 +1,64 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import javax.crypto.BadPaddingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.ClientAuthenticationException;
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
*/
class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
private boolean websocketHandshakeComplete = false;
private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
setWebsocketHandshakeComplete();
}
context.fireUserEventTriggered(event);
}
protected void setWebsocketHandshakeComplete() {
this.websocketHandshakeComplete = true;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof OutboundCloseErrorMessage err) {
if (websocketHandshakeComplete) {
final int status = switch (err.code()) {
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
case AUTHENTICATION_ERROR -> ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode();
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
};
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} else {
log.debug("Error {} occurred before websocket handshake complete", err);
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
// way; just close the connection instead.
ctx.close();
}
} else {
ctx.write(msg, promise);
}
}
}

View File

@@ -1,4 +1,4 @@
package org.whispersystems.textsecuregcm.grpc.net;
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.InetAddresses;
@@ -20,6 +20,9 @@ import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.NoiseAnonymousHandler;
import org.whispersystems.textsecuregcm.grpc.net.NoiseAuthenticatedHandler;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
@@ -74,7 +77,7 @@ class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
preferredRemoteAddress = maybePreferredRemoteAddress.get();
}
GrpcClientConnectionManager.handleHandshakeComplete(context.channel(),
GrpcClientConnectionManager.handleHandshakeInitiated(context.channel(),
preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));

View File

@@ -0,0 +1,38 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
/**
* Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
* packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
*/
public class WebsocketPayloadCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof BinaryWebSocketFrame frame) {
ctx.fireChannelRead(frame.content());
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof ByteBuf bb) {
ctx.write(new BinaryWebSocketFrame(bb), promise);
} else {
ctx.write(msg, promise);
}
}
}