Gracefully handle proxy protocol messages at the beginning of TCP connections

This commit is contained in:
Jon Chambers
2024-05-24 09:11:19 -04:00
committed by GitHub
parent 1678045ce4
commit 9ec4f0b2f5
11 changed files with 361 additions and 11 deletions

View File

@@ -8,6 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
@@ -21,6 +23,7 @@ import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
@@ -39,6 +42,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final HttpHeaders headers;
private final SocketAddress remoteServerAddress;
private final WebSocketCloseListener webSocketCloseListener;
@Nullable private final Supplier<HAProxyMessage> proxyMessageSupplier;
private final List<Object> pendingReads = new ArrayList<>();
@@ -55,7 +59,8 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
final byte deviceId,
final HttpHeaders headers,
final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener) {
final WebSocketCloseListener webSocketCloseListener,
@Nullable Supplier<HAProxyMessage> proxyMessageSupplier) {
this.useTls = useTls;
this.trustedServerCertificate = trustedServerCertificate;
@@ -68,6 +73,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
this.headers = headers;
this.remoteServerAddress = remoteServerAddress;
this.webSocketCloseListener = webSocketCloseListener;
this.proxyMessageSupplier = proxyMessageSupplier;
}
@Override
@@ -78,6 +84,16 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel channel) throws SSLException {
if (proxyMessageSupplier != null) {
// In a production setting, we'd want some mechanism to remove these handlers after the initial message
// were sent. Since this is just for testing, though, we can tolerate the inefficiency of leaving a
// pair of inert handlers in the pipeline.
channel.pipeline()
.addLast(HAProxyMessageEncoder.INSTANCE)
.addLast(new HAProxyMessageSender(proxyMessageSupplier));
}
if (useTls) {
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();

View File

@@ -0,0 +1,62 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
class HAProxyMessageHandlerTest {
private EmbeddedChannel embeddedChannel;
@BeforeEach
void setUp() {
embeddedChannel = new EmbeddedChannel(new HAProxyMessageHandler());
}
@Test
void handleHAProxyMessage() throws InterruptedException {
final HAProxyMessage haProxyMessage = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(haProxyMessage);
embeddedChannel.flushInbound();
writeFuture.await();
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(0, haProxyMessage.refCnt());
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
}
@Test
void handleNonHAProxyMessage() throws InterruptedException {
final byte[] bytes = new byte[32];
ThreadLocalRandom.current().nextBytes(bytes);
final ByteBuf message = Unpooled.wrappedBuffer(bytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message);
embeddedChannel.flushInbound();
writeFuture.await();
assertEquals(1, embeddedChannel.inboundMessages().size());
assertEquals(message, embeddedChannel.inboundMessages().poll());
assertEquals(1, message.refCnt());
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
}
}

View File

@@ -0,0 +1,28 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import java.util.function.Supplier;
class HAProxyMessageSender extends ChannelInboundHandlerAdapter {
private final Supplier<HAProxyMessage> messageSupplier;
HAProxyMessageSender(final Supplier<HAProxyMessage> messageSupplier) {
this.messageSupplier = messageSupplier;
}
@Override
public void handlerAdded(final ChannelHandlerContext context) {
if (context.channel().isActive()) {
context.writeAndFlush(messageSupplier.get());
}
}
@Override
public void channelActive(final ChannelHandlerContext context) {
context.writeAndFlush(messageSupplier.get());
context.fireChannelActive();
}
}

View File

@@ -11,6 +11,8 @@ import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.UUID;
import java.util.function.Supplier;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.http.HttpHeaders;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
@@ -34,6 +36,7 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
final HttpHeaders headers,
final boolean useTls,
@Nullable final X509Certificate trustedServerCertificate,
@Nullable final Supplier<HAProxyMessage> proxyMessageSupplier,
final NioEventLoopGroup eventLoopGroup,
final WebSocketCloseListener webSocketCloseListener) {
@@ -54,7 +57,8 @@ class NoiseWebSocketTunnelClient implements AutoCloseable {
deviceId,
headers,
remoteServerAddress,
webSocketCloseListener));
webSocketCloseListener,
proxyMessageSupplier));
}
});
}

View File

@@ -16,6 +16,10 @@ import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.io.ByteArrayInputStream;
@@ -54,6 +58,8 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
@@ -234,9 +240,10 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@Test
void connectAuthenticated() throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient()) {
@ParameterizedTest
@ValueSource(booleans = { true, false })
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), includeProxyMessage)) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
@@ -251,8 +258,9 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
@Test
void connectAuthenticatedPlaintext() throws InterruptedException {
@ParameterizedTest
@ValueSource(booleans = { true, false })
void connectAuthenticatedPlaintext(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseWebSocketTunnelClient client = new NoiseWebSocketTunnelClient(
tlsNoiseWebSocketTunnelServer.getLocalAddress(),
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
@@ -264,6 +272,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
nioEventLoopGroup,
WebSocketCloseListener.NOOP_LISTENER)
.start()) {
@@ -289,7 +298,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseWebSocketTunnelClient client =
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders(), false)) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
@@ -369,6 +378,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
null,
nioEventLoopGroup,
webSocketCloseListener)
.start()) {
@@ -443,6 +453,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
null,
nioEventLoopGroup,
webSocketCloseListener)
.start()) {
@@ -602,12 +613,13 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
throws InterruptedException {
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders(), false);
}
private NoiseWebSocketTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
final HttpHeaders headers,
final boolean includeProxyMessage) throws InterruptedException {
return new NoiseWebSocketTunnelClient(tlsNoiseWebSocketTunnelServer.getLocalAddress(),
NoiseWebSocketTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
@@ -619,6 +631,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
headers,
true,
serverTlsCertificate,
includeProxyMessage ? NoiseWebSocketTunnelServerIntegrationTest::buildProxyMessage : null,
nioEventLoopGroup,
webSocketCloseListener)
.start();
@@ -642,8 +655,14 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
headers,
true,
serverTlsCertificate,
null,
nioEventLoopGroup,
webSocketCloseListener)
.start();
}
private static HAProxyMessage buildProxyMessage() {
return new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443);
}
}

View File

@@ -0,0 +1,108 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import java.util.HexFormat;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
class ProxyProtocolDetectionHandlerTest {
private EmbeddedChannel embeddedChannel;
private static final byte[] PROXY_V2_MESSAGE_BYTES =
HexFormat.of().parseHex("0d0a0d0a000d0a515549540a2111000c0a0000010a000002303901bb");
@BeforeEach
void setUp() {
embeddedChannel = new EmbeddedChannel(new ProxyProtocolDetectionHandler());
}
@Test
void singlePacketProxyMessage() throws InterruptedException {
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES));
embeddedChannel.flushInbound();
writeFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
}
@Test
void multiPacketProxyMessage() throws InterruptedException {
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, 0,
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
PROXY_V2_MESSAGE_BYTES.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
embeddedChannel.flushInbound();
firstWriteFuture.await();
secondWriteFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
}
@Test
void singlePacketNonProxyMessage() throws InterruptedException {
final byte[] nonProxyProtocolMessage = new byte[32];
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(nonProxyProtocolMessage));
embeddedChannel.flushInbound();
writeFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
assertInstanceOf(ByteBuf.class, inboundMessage);
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
}
@Test
void multiPacketNonProxyMessage() throws InterruptedException {
final byte[] nonProxyProtocolMessage = new byte[32];
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(nonProxyProtocolMessage, 0,
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(nonProxyProtocolMessage, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
nonProxyProtocolMessage.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
embeddedChannel.flushInbound();
firstWriteFuture.await();
secondWriteFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
assertInstanceOf(ByteBuf.class, inboundMessage);
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
}
}

View File

@@ -13,7 +13,6 @@ import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;