Introduce a (dormant) Noise/WebSocket for future client/server communication

This commit is contained in:
Jon Chambers
2024-02-23 11:42:42 -05:00
committed by GitHub
parent d2716fe5cf
commit a5774bf6ff
45 changed files with 3262 additions and 84 deletions

View File

@@ -0,0 +1,21 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.util.ResourceLeakDetector;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
abstract class AbstractLeakDetectionTest {
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
@BeforeAll
static void setLeakDetectionLevel() {
originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel();
ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
}
@AfterAll
static void restoreLeakDetectionLevel() {
ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel);
}
}

View File

@@ -0,0 +1,94 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import java.security.NoSuchAlgorithmException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
abstract class AbstractNoiseClientHandler extends ChannelInboundHandlerAdapter {
private final ECPublicKey rootPublicKey;
private final HandshakeState handshakeState;
AbstractNoiseClientHandler(final ECPublicKey rootPublicKey) {
this.rootPublicKey = rootPublicKey;
try {
handshakeState = new HandshakeState(getNoiseProtocolName(), HandshakeState.INITIATOR);
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Unsupported Noise algorithm: " + getNoiseProtocolName(), e);
}
}
protected abstract String getNoiseProtocolName();
protected abstract void startHandshake();
protected HandshakeState getHandshakeState() {
return handshakeState;
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
startHandshake();
final byte[] ephemeralKeyMessage = new byte[32];
handshakeState.writeMessage(ephemeralKeyMessage, 0, null, 0, 0);
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
}
}
super.userEventTriggered(context, event);
}
protected void handleServerStaticKeyMessage(final ChannelHandlerContext context, final BinaryWebSocketFrame frame)
throws NoiseHandshakeException {
// The frame is coming right off the wire and so will be a direct buffer not backed by an array; copy it to a heap
// buffer so we can Noise at it.
final ByteBuf keyMaterialBuffer = context.alloc().heapBuffer(frame.content().readableBytes());
final byte[] serverPublicKeySignature = new byte[64];
try {
frame.content().readBytes(keyMaterialBuffer);
final int payloadBytesRead =
handshakeState.readMessage(keyMaterialBuffer.array(), keyMaterialBuffer.arrayOffset(), keyMaterialBuffer.readableBytes(), serverPublicKeySignature, 0);
if (payloadBytesRead != 64) {
throw new NoiseHandshakeException("Unexpected signature length");
}
} catch (final ShortBufferException e) {
throw new NoiseHandshakeException("Unexpected signature length");
} catch (final BadPaddingException e) {
throw new NoiseHandshakeException("Invalid keys");
} finally {
keyMaterialBuffer.release();
}
final byte[] serverPublicKey = new byte[32];
handshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
if (!rootPublicKey.verifySignature(serverPublicKey, serverPublicKeySignature)) {
throw new NoiseHandshakeException("Invalid server public key signature");
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) throws Exception {
handshakeState.destroy();
}
}

View File

@@ -0,0 +1,141 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import java.util.concurrent.ThreadLocalRandom;
import javax.annotation.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
abstract class AbstractNoiseHandshakeHandlerTest extends AbstractLeakDetectionTest {
private ECPublicKey rootPublicKey;
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
private EmbeddedChannel embeddedChannel;
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
@Nullable
private NoiseHandshakeCompleteEvent handshakeCompleteEvent = null;
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
handshakeCompleteEvent = noiseHandshakeCompleteEvent;
} else {
context.fireUserEventTriggered(event);
}
}
@Nullable
public NoiseHandshakeCompleteEvent getHandshakeCompleteEvent() {
return handshakeCompleteEvent;
}
}
@BeforeEach
void setUp() {
final ECKeyPair rootKeyPair = Curve.generateKeyPair();
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
rootPublicKey = rootKeyPair.getPublicKey();
final byte[] serverPublicKeySignature =
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes());
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
embeddedChannel =
new EmbeddedChannel(getHandler(serverKeyPair, serverPublicKeySignature), noiseHandshakeCompleteHandler);
}
@AfterEach
void tearDown() {
embeddedChannel.close();
}
protected EmbeddedChannel getEmbeddedChannel() {
return embeddedChannel;
}
protected ECPublicKey getRootPublicKey() {
return rootPublicKey;
}
@Nullable
protected NoiseHandshakeCompleteEvent getNoiseHandshakeCompleteEvent() {
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
}
protected abstract AbstractNoiseHandshakeHandler getHandler(final ECKeyPair serverKeyPair, final byte[] serverPublicKeySignature);
@Test
void handleInvalidInitialMessage() throws InterruptedException {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(content)).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
assertEquals(0, content.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
final BinaryWebSocketFrame[] frames = new BinaryWebSocketFrame[7];
for (int i = 0; i < frames.length; i++) {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
frames[i] = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
embeddedChannel.writeOneInbound(frames[i]).await();
}
for (final BinaryWebSocketFrame frame : frames) {
assertEquals(0, frame.refCnt());
}
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleNonWebSocketBinaryFrame() throws InterruptedException {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf message = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
assertEquals(0, message.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
}

View File

@@ -0,0 +1,21 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.grpc.stub.StreamObserver;
import org.signal.chat.rpc.AuthenticationTypeGrpc;
import org.signal.chat.rpc.GetAuthenticatedRequest;
import org.signal.chat.rpc.GetAuthenticatedResponse;
public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase {
private final boolean authenticated;
public AuthenticationTypeService(final boolean authenticated) {
this.authenticated = authenticated;
}
@Override
public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver<GetAuthenticatedResponse> responseObserver) {
responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build());
responseObserver.onCompleted();
}
}

View File

@@ -0,0 +1,18 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
class ClientErrorHandler extends ErrorHandler {
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
setWebsocketHandshakeComplete();
}
}
super.userEventTriggered(context, event);
}
}

View File

@@ -0,0 +1,141 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.Noise;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final X509Certificate trustedServerCertificate;
private final URI websocketUri;
private final boolean authenticated;
@Nullable private final ECKeyPair ecKeyPair;
private final ECPublicKey rootPublicKey;
@Nullable private final UUID accountIdentifier;
private final byte deviceId;
private final SocketAddress remoteServerAddress;
private final WebSocketCloseListener webSocketCloseListener;
private final List<Object> pendingReads = new ArrayList<>();
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
EstablishRemoteConnectionHandler(
final X509Certificate trustedServerCertificate,
final URI websocketUri,
final boolean authenticated,
@Nullable final ECKeyPair ecKeyPair,
final ECPublicKey rootPublicKey,
@Nullable final UUID accountIdentifier,
final byte deviceId,
final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener) {
this.trustedServerCertificate = trustedServerCertificate;
this.websocketUri = websocketUri;
this.authenticated = authenticated;
this.ecKeyPair = ecKeyPair;
this.rootPublicKey = rootPublicKey;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.remoteServerAddress = remoteServerAddress;
this.webSocketCloseListener = webSocketCloseListener;
}
@Override
public void handlerAdded(final ChannelHandlerContext localContext) {
new Bootstrap()
.channel(NioSocketChannel.class)
.group(localContext.channel().eventLoop())
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel channel) throws SSLException {
channel.pipeline()
.addLast(SslContextBuilder
.forClient()
.trustManager(trustedServerCertificate)
.build()
.newHandler(channel.alloc()))
.addLast(new HttpClientCodec())
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
// want to react to them on our own, we need to catch them before they hit that handler.
.addLast(new InboundCloseWebSocketFrameHandler(webSocketCloseListener))
.addLast(new WebSocketClientProtocolHandler(websocketUri,
WebSocketVersion.V13,
null,
false,
new DefaultHttpHeaders(),
Noise.MAX_PACKET_LEN,
10_000))
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))
.addLast(authenticated
? new NoiseXXClientHandshakeHandler(ecKeyPair, rootPublicKey, accountIdentifier, deviceId)
: new NoiseNXClientHandshakeHandler(rootPublicKey))
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
throws Exception {
if (event instanceof NoiseHandshakeCompleteEvent) {
remoteContext.pipeline()
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
pendingReads.forEach(localContext::fireChannelRead);
pendingReads.clear();
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
}
super.userEventTriggered(remoteContext, event);
}
})
.addLast(new ClientErrorHandler());
}
})
.connect(remoteServerAddress)
.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
// Close the local connection if the remote channel closes and vice versa
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
} else {
localContext.close();
}
});
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
pendingReads.add(message);
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
pendingReads.forEach(ReferenceCountUtil::release);
pendingReads.clear();
}
}

View File

@@ -0,0 +1,23 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
class InboundCloseWebSocketFrameHandler extends ChannelInboundHandlerAdapter {
private final WebSocketCloseListener webSocketCloseListener;
public InboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
this.webSocketCloseListener = webSocketCloseListener;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
webSocketCloseListener.handleWebSocketClosedByServer(closeWebSocketFrame.statusCode());
}
super.channelRead(context, message);
}
}

View File

@@ -0,0 +1,47 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.util.Optional;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
class NoiseNXClientHandshakeHandler extends AbstractNoiseClientHandler {
private boolean receivedServerStaticKeyMessage = false;
NoiseNXClientHandshakeHandler(final ECPublicKey rootPublicKey) {
super(rootPublicKey);
}
@Override
protected String getNoiseProtocolName() {
return NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME;
}
@Override
protected void startHandshake() {
getHandshakeState().start();
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof BinaryWebSocketFrame frame) {
try {
// Don't process additional messages if we're just waiting to close because the handshake failed
if (receivedServerStaticKeyMessage) {
return;
}
receivedServerStaticKeyMessage = true;
handleServerStaticKeyMessage(context, frame);
context.pipeline().replace(this, null, new NoiseStreamHandler(getHandshakeState().split()));
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
}

View File

@@ -0,0 +1,84 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
class NoiseNXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
@Override
protected NoiseNXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
final byte[] serverPublicKeySignature) {
return new NoiseNXHandshakeHandler(serverKeyPair, serverPublicKeySignature);
}
@Test
void handleCompleteHandshake()
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException, BadPaddingException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class));
final HandshakeState clientHandshakeState =
new HandshakeState(NoiseNXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
clientHandshakeState.start();
{
final byte[] ephemeralKeyMessageBytes = new byte[32];
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
}
{
assertEquals(1, embeddedChannel.outboundMessages().size());
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
final byte[] serverPublicKeySignature = new byte[64];
final int payloadLength =
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
assertEquals(serverPublicKeySignature.length, payloadLength);
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
}
assertEquals(new NoiseHandshakeCompleteEvent(Optional.empty()), getNoiseHandshakeCompleteEvent());
assertNull(embeddedChannel.pipeline().get(NoiseNXHandshakeHandler.class),
"Handshake handler should remove self from pipeline after successful handshake");
assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Handshake handler should insert a Noise stream handler after successful handshake");
}
}

View File

@@ -0,0 +1,135 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import static org.junit.jupiter.api.Assertions.*;
class NoiseStreamHandlerTest extends AbstractLeakDetectionTest {
private CipherStatePair clientCipherStatePair;
private EmbeddedChannel embeddedChannel;
// We use an NN handshake for this test just because it's a little shorter and easier to set up
private static final String NOISE_PROTOCOL_NAME = "Noise_NN_25519_ChaChaPoly_BLAKE2b";
@BeforeEach
void setUp() throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException {
final HandshakeState clientHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
final HandshakeState serverHandshakeState = new HandshakeState(NOISE_PROTOCOL_NAME, HandshakeState.RESPONDER);
clientHandshakeState.start();
serverHandshakeState.start();
final byte[] clientEphemeralKeyMessage = new byte[32];
assertEquals(clientEphemeralKeyMessage.length,
clientHandshakeState.writeMessage(clientEphemeralKeyMessage, 0, null, 0, 0));
serverHandshakeState.readMessage(clientEphemeralKeyMessage, 0, clientEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
// 32 bytes of key material plus a 16-byte MAC
final byte[] serverEphemeralKeyMessage = new byte[48];
assertEquals(serverEphemeralKeyMessage.length,
serverHandshakeState.writeMessage(serverEphemeralKeyMessage, 0, null, 0, 0));
clientHandshakeState.readMessage(serverEphemeralKeyMessage, 0, serverEphemeralKeyMessage.length, EmptyArrays.EMPTY_BYTES, 0);
clientCipherStatePair = clientHandshakeState.split();
embeddedChannel = new EmbeddedChannel(new NoiseStreamHandler(serverHandshakeState.split()));
clientHandshakeState.destroy();
serverHandshakeState.destroy();
}
@Test
void channelRead() throws ShortBufferException, InterruptedException {
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ciphertext));
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
assertEquals(0, ciphertextFrame.refCnt());
final ByteBuf decryptedPlaintextBuffer = (ByteBuf) embeddedChannel.inboundMessages().poll();
assertNotNull(decryptedPlaintextBuffer);
assertTrue(embeddedChannel.inboundMessages().isEmpty());
final byte[] decryptedPlaintext = ByteBufUtil.getBytes(decryptedPlaintextBuffer);
decryptedPlaintextBuffer.release();
assertArrayEquals(plaintext, decryptedPlaintext);
}
@Test
void channelReadBadCiphertext() throws InterruptedException {
final byte[] bogusCiphertext = new byte[32];
ThreadLocalRandom.current().nextBytes(bogusCiphertext);
final BinaryWebSocketFrame ciphertextFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(bogusCiphertext));
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
assertEquals(0, ciphertextFrame.refCnt());
assertFalse(readCiphertextFuture.isSuccess());
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void channelReadUnexpectedMessageType() throws InterruptedException {
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
assertFalse(readFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void write() throws InterruptedException, ShortBufferException, BadPaddingException {
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
assertTrue(writePlaintextFuture.await().isSuccess());
assertEquals(0, plaintextBuffer.refCnt());
final BinaryWebSocketFrame ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
assertNotNull(ciphertextFrame);
assertTrue(embeddedChannel.outboundMessages().isEmpty());
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
ciphertextFrame.release();
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
assertArrayEquals(plaintext, decryptedPlaintext);
}
@Test
void writeUnexpectedMessageType() throws InterruptedException {
final Object unexpectedMessaged = new Object();
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
assertTrue(writeFuture.await().isSuccess());
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
assertTrue(embeddedChannel.outboundMessages().isEmpty());
}
}

View File

@@ -0,0 +1,89 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.UUID;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
class NoiseXXClientHandshakeHandler extends AbstractNoiseClientHandler {
private final ECKeyPair ecKeyPair;
private final UUID accountIdentifier;
private final byte deviceId;
private boolean receivedServerStaticKeyMessage = false;
NoiseXXClientHandshakeHandler(final ECKeyPair ecKeyPair,
final ECPublicKey rootPublicKey,
final UUID accountIdentifier,
final byte deviceId) {
super(rootPublicKey);
this.ecKeyPair = ecKeyPair;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
}
@Override
protected String getNoiseProtocolName() {
return NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME;
}
@Override
protected void startHandshake() {
final HandshakeState handshakeState = getHandshakeState();
// Noise-java derives the public key from the private key, so we just need to set the private key
handshakeState.getLocalKeyPair().setPrivateKey(ecKeyPair.getPrivateKey().serialize(), 0);
handshakeState.start();
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message)
throws NoiseHandshakeException, ShortBufferException {
if (message instanceof BinaryWebSocketFrame frame) {
try {
// Don't process additional messages if the handshake failed and we're just waiting to close
if (receivedServerStaticKeyMessage) {
return;
}
receivedServerStaticKeyMessage = true;
handleServerStaticKeyMessage(context, frame);
final ByteBuffer clientIdentityBuffer = ByteBuffer.allocate(17);
clientIdentityBuffer.putLong(accountIdentifier.getMostSignificantBits());
clientIdentityBuffer.putLong(accountIdentifier.getLeastSignificantBits());
clientIdentityBuffer.put(deviceId);
clientIdentityBuffer.flip();
final HandshakeState handshakeState = getHandshakeState();
// We're sending two 32-byte keys plus the client identity payload
final byte[] staticKeyAndIdentityMessage = new byte[64 + clientIdentityBuffer.remaining()];
handshakeState.writeMessage(
staticKeyAndIdentityMessage, 0, clientIdentityBuffer.array(), 0, clientIdentityBuffer.remaining());
context.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(staticKeyAndIdentityMessage)))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
context.pipeline().replace(this, null, new NoiseStreamHandler(handshakeState.split()));
context.fireUserEventTriggered(new NoiseHandshakeCompleteEvent(Optional.empty()));
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
}

View File

@@ -0,0 +1,454 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.southernstorm.noise.protocol.CipherState;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.EmptyArrays;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
class NoiseXXHandshakeHandlerTest extends AbstractNoiseHandshakeHandlerTest {
private ClientPublicKeysManager clientPublicKeysManager;
@Override
@BeforeEach
void setUp() {
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
super.setUp();
}
@Override
protected NoiseXXHandshakeHandler getHandler(final ECKeyPair serverKeyPair,
final byte[] serverPublicKeySignature) {
return new NoiseXXHandshakeHandler(clientPublicKeysManager, serverKeyPair, serverPublicKeySignature);
}
@Test
void handleCompleteHandshake()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
getNoiseHandshakeCompleteEvent());
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should remove self from pipeline after successful handshake");
assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Handshake handler should insert a Noise stream handler after successful handshake");
}
@Test
void handleCompleteHandshakeMissingIdentityInformation()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
{
final byte[] clientStaticKeyMessageBytes = new byte[64];
final int messageLength =
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, EmptyArrays.EMPTY_BYTES, 0, 0);
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
final ChannelFuture writeClientStaticKeyMessageFuture =
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
}
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeMalformedIdentityInformation()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
{
final byte[] clientStaticKeyMessageBytes = new byte[96];
final int messageLength =
clientHandshakeState.writeMessage(clientStaticKeyMessageBytes, 0, new byte[32], 0, 32);
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
final ChannelFuture writeClientStaticKeyMessageFuture =
getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await();
assertFalse(writeClientStaticKeyMessageFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeClientStaticKeyMessageFuture.cause());
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
}
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeUnrecognizedDevice()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakePublicKeyMismatch()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertThrows(ClientAuthenticationException.class, embeddedChannel::checkException);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeBufferedReads()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
final ByteBuf[] additionalMessages = new ByteBuf[4];
final CipherState senderState = clientHandshakeState.split().getSender();
try {
for (int i = 0; i < additionalMessages.length; i++) {
final byte[] contentBytes = new byte[32];
ThreadLocalRandom.current().nextBytes(contentBytes);
// Copy the "plaintext" portion of the content bytes for future assertions
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
// tag
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
assertTrue(
embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await()
.isSuccess());
}
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertEquals(new NoiseHandshakeCompleteEvent(Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))),
getNoiseHandshakeCompleteEvent());
assertNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should remove self from pipeline after successful handshake");
assertNotNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Handshake handler should insert a Noise stream handler after successful handshake");
for (final ByteBuf additionalMessage : additionalMessages) {
assertEquals(additionalMessage, embeddedChannel.inboundMessages().poll(),
"Buffered message should pass through pipeline after successful handshake");
}
} finally {
for (final ByteBuf additionalMessage : additionalMessages) {
additionalMessage.release();
}
}
}
@Test
void handleCompleteHandshakeFailureBufferedReads()
throws ShortBufferException, NoSuchAlgorithmException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = (byte) ThreadLocalRandom.current().nextInt(Device.MAXIMUM_DEVICE_ID);
final ECKeyPair clientKeyPair = Curve.generateKeyPair();
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
final HandshakeState clientHandshakeState = exchangeClientEphemeralAndServerStaticMessages(clientKeyPair);
sendClientStaticKey(clientHandshakeState, accountIdentifier, deviceId);
final ByteBuf[] additionalMessages = new ByteBuf[4];
final CipherState senderState = clientHandshakeState.split().getSender();
try {
for (int i = 0; i < additionalMessages.length; i++) {
final byte[] contentBytes = new byte[32];
ThreadLocalRandom.current().nextBytes(contentBytes);
// Copy the "plaintext" portion of the content bytes for future assertions
additionalMessages[i] = Unpooled.buffer(16).writeBytes(contentBytes, 0, 16);
// Overwrite the first 16 bytes of a random "plaintext" with a ciphertext and the second 16 bytes with the AEAD
// tag
senderState.encryptWithAd(null, contentBytes, 0, contentBytes, 0, 16);
assertTrue(embeddedChannel.writeOneInbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes))).await().isSuccess());
}
findPublicKeyFuture.complete(Optional.empty());
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseXXHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseStreamHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
assertTrue(embeddedChannel.inboundMessages().isEmpty(),
"Buffered messages should not pass through pipeline after failed handshake");
} finally {
for (final ByteBuf additionalMessage : additionalMessages) {
additionalMessage.release();
}
}
}
private HandshakeState exchangeClientEphemeralAndServerStaticMessages(final ECKeyPair clientKeyPair)
throws NoSuchAlgorithmException, ShortBufferException, BadPaddingException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
final HandshakeState clientHandshakeState =
new HandshakeState(NoiseXXHandshakeHandler.NOISE_PROTOCOL_NAME, HandshakeState.INITIATOR);
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
clientHandshakeState.start();
{
final byte[] ephemeralKeyMessageBytes = new byte[32];
clientHandshakeState.writeMessage(ephemeralKeyMessageBytes, 0, null, 0, 0);
final BinaryWebSocketFrame ephemeralKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(ephemeralKeyMessageBytes));
assertTrue(embeddedChannel.writeOneInbound(ephemeralKeyMessageFrame).await().isSuccess());
assertEquals(0, ephemeralKeyMessageFrame.refCnt());
}
{
assertEquals(1, embeddedChannel.outboundMessages().size());
final BinaryWebSocketFrame serverStaticKeyMessageFrame =
(BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll();
@SuppressWarnings("DataFlowIssue") final byte[] serverStaticKeyMessageBytes =
new byte[serverStaticKeyMessageFrame.content().readableBytes()];
serverStaticKeyMessageFrame.content().readBytes(serverStaticKeyMessageBytes);
final byte[] serverPublicKeySignature = new byte[64];
final int payloadLength =
clientHandshakeState.readMessage(serverStaticKeyMessageBytes, 0, serverStaticKeyMessageBytes.length, serverPublicKeySignature, 0);
assertEquals(serverPublicKeySignature.length, payloadLength);
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertTrue(getRootPublicKey().verifySignature(serverPublicKey, serverPublicKeySignature));
}
return clientHandshakeState;
}
private void sendClientStaticKey(final HandshakeState handshakeState, final UUID accountIdentifier, final byte deviceId)
throws ShortBufferException, InterruptedException {
final ByteBuffer clientIdentityPayloadBuffer = ByteBuffer.allocate(17);
clientIdentityPayloadBuffer.putLong(accountIdentifier.getMostSignificantBits());
clientIdentityPayloadBuffer.putLong(accountIdentifier.getLeastSignificantBits());
clientIdentityPayloadBuffer.put(deviceId);
clientIdentityPayloadBuffer.flip();
final byte[] clientStaticKeyMessageBytes = new byte[81];
final int messageLength =
handshakeState.writeMessage(clientStaticKeyMessageBytes, 0, clientIdentityPayloadBuffer.array(), 0, clientIdentityPayloadBuffer.remaining());
assertEquals(clientStaticKeyMessageBytes.length, messageLength);
final BinaryWebSocketFrame clientStaticKeyMessageFrame =
new BinaryWebSocketFrame(Unpooled.wrappedBuffer(clientStaticKeyMessageBytes));
assertTrue(getEmbeddedChannel().writeOneInbound(clientStaticKeyMessageFrame).await().isSuccess());
assertEquals(0, clientStaticKeyMessageFrame.refCnt());
}
}

View File

@@ -0,0 +1,24 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
class OutboundCloseWebSocketFrameHandler extends ChannelOutboundHandlerAdapter {
private final WebSocketCloseListener webSocketCloseListener;
OutboundCloseWebSocketFrameHandler(final WebSocketCloseListener webSocketCloseListener) {
this.webSocketCloseListener = webSocketCloseListener;
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
webSocketCloseListener.handleWebSocketClosedByClient(closeWebSocketFrame.statusCode());
}
super.write(context, message, promise);
}
}

View File

@@ -0,0 +1,72 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {
private EmbeddedChannel embeddedChannel;
@BeforeEach
void setUp() {
embeddedChannel = new EmbeddedChannel(new RejectUnsupportedMessagesHandler());
}
@ParameterizedTest
@MethodSource
void allowWebSocketFrame(final WebSocketFrame frame) {
embeddedChannel.writeOneInbound(frame);
try {
assertEquals(frame, embeddedChannel.inboundMessages().poll());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(1, frame.refCnt());
} finally {
frame.release();
}
}
private static List<WebSocketFrame> allowWebSocketFrame() {
return List.of(
new BinaryWebSocketFrame(),
new CloseWebSocketFrame(),
new ContinuationWebSocketFrame(),
new PingWebSocketFrame(),
new PongWebSocketFrame());
}
@Test
void rejectTextFrame() {
final TextWebSocketFrame textFrame = new TextWebSocketFrame();
embeddedChannel.writeOneInbound(textFrame);
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(0, textFrame.refCnt());
}
@Test
void rejectNonWebSocketFrame() {
final ByteBuf bytes = Unpooled.buffer(0);
embeddedChannel.writeOneInbound(bytes);
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(0, bytes.refCnt());
}
}

View File

@@ -0,0 +1,18 @@
package org.whispersystems.textsecuregcm.grpc.net;
interface WebSocketCloseListener {
WebSocketCloseListener NOOP_LISTENER = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
}
};
void handleWebSocketClosedByClient(int statusCode);
void handleWebSocketClosedByServer(int statusCode);
}

View File

@@ -0,0 +1,70 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.UUID;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import javax.annotation.Nullable;
class WebSocketNoiseTunnelClient implements AutoCloseable {
private final ServerBootstrap serverBootstrap;
private Channel serverChannel;
static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
public WebSocketNoiseTunnelClient(final SocketAddress remoteServerAddress,
final URI websocketUri,
final boolean authenticated,
final ECKeyPair ecKeyPair,
final ECPublicKey rootPublicKey,
@Nullable final UUID accountIdentifier,
final byte deviceId,
final X509Certificate trustedServerCertificate,
final NioEventLoopGroup eventLoopGroup,
final WebSocketCloseListener webSocketCloseListener) {
this.serverBootstrap = new ServerBootstrap()
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
.channel(LocalServerChannel.class)
.group(eventLoopGroup)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(trustedServerCertificate,
websocketUri,
authenticated,
ecKeyPair,
rootPublicKey,
accountIdentifier,
deviceId,
remoteServerAddress,
webSocketCloseListener));
}
});
}
LocalAddress getLocalAddress() {
return (LocalAddress) serverChannel.localAddress();
}
WebSocketNoiseTunnelClient start() throws InterruptedException {
serverChannel = serverBootstrap.bind().await().channel();
return this;
}
@Override
public void close() throws InterruptedException {
serverChannel.close().await();
}
}

View File

@@ -0,0 +1,486 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.AuthenticationTypeGrpc;
import org.signal.chat.rpc.GetAuthenticatedRequest;
import org.signal.chat.rpc.GetAuthenticatedResponse;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor;
private static X509Certificate serverTlsCertificate;
private ClientPublicKeysManager clientPublicKeysManager;
private ECKeyPair rootKeyPair;
private ECKeyPair clientKeyPair;
private ManagedLocalGrpcServer authenticatedGrpcServer;
private ManagedLocalGrpcServer anonymousGrpcServer;
private WebsocketNoiseTunnelServer websocketNoiseTunnelServer;
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
// They were generated with:
//
// ```shell
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
// ```
private static final String SERVER_CERTIFICATE = """
-----BEGIN CERTIFICATE-----
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
iOr9sHiO8Rn2u0xRKgU5Ig==
-----END CERTIFICATE-----
""";
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
private static final String SERVER_PRIVATE_KEY = """
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
""";
@BeforeAll
static void setUpBeforeAll() throws CertificateException {
nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newSingleThreadExecutor();
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
}
@BeforeEach
void setUp() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException, InterruptedException {
final PrivateKey serverTlsPrivateKey;
{
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
serverTlsPrivateKey =
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
}
rootKeyPair = Curve.generateKeyPair();
clientKeyPair = Curve.generateKeyPair();
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new AuthenticationTypeService(true));
}
};
authenticatedGrpcServer.start();
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new AuthenticationTypeService(false));
}
};
anonymousGrpcServer.start();
websocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
new X509Certificate[] { serverTlsCertificate },
serverTlsPrivateKey,
nioEventLoopGroup,
delegatedTaskExecutor,
clientPublicKeysManager,
serverKeyPair,
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress);
websocketNoiseTunnelServer.start();
}
@AfterEach
void tearDown() throws InterruptedException {
websocketNoiseTunnelServer.stop();
authenticatedGrpcServer.stop();
anonymousGrpcServer.stop();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@Test
void connectAuthenticated() throws InterruptedException {
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient()) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
assertTrue(response.getAuthenticated());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
// Try to verify the server's public key with something other than the key with which it was signed
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(Curve.generateKeyPair().getPublicKey())));
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAuthenticatedClient(webSocketCloseListener)) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
}
@Test
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAuthenticatedClient(webSocketCloseListener)) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByServer(ApplicationWebSocketCloseReason.CLIENT_AUTHENTICATION_ERROR.getStatusCode());
}
@Test
void connectAuthenticatedToAnonymousService() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
websocketNoiseTunnelServer.getLocalAddress(),
URI.create("wss://localhost/anonymous"),
true,
clientKeyPair,
rootKeyPair.getPublicKey(),
ACCOUNT_IDENTIFIER,
DEVICE_ID,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
.start()) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAnonymous() throws InterruptedException {
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAnonymousClient()) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
assertFalse(response.getAuthenticated());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAnonymousBadServerKeySignature() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
// Try to verify the server's public key with something other than the key with which it was signed
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
@Test
void connectAnonymousToAuthenticatedService() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final WebSocketNoiseTunnelClient websocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
websocketNoiseTunnelServer.getLocalAddress(),
URI.create("wss://localhost/authenticated"),
false,
null,
rootKeyPair.getPublicKey(),
null,
(byte) 0,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
.start()) {
final ManagedChannel channel = buildManagedChannel(websocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
}
verify(webSocketCloseListener).handleWebSocketClosedByClient(ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode());
}
private ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
return NettyChannelBuilder.forAddress(localAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(defaultEventLoopGroup)
.usePlaintext()
.build();
}
@Test
void rejectIllegalRequests() throws Exception {
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
final TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(keyStore);
final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
final URI authenticatedUri =
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
final URI incorrectUri =
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
.uri(authenticatedUri)
.PUT(HttpRequest.BodyPublishers.ofString("test"))
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"Non-GET requests should not be allowed");
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(authenticatedUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests without upgrade headers should not be allowed");
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(incorrectUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests to unrecognized URIs should not be allowed");
}
}
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
}
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
throws InterruptedException {
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey());
}
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
true,
clientKeyPair,
rootPublicKey,
ACCOUNT_IDENTIFIER,
DEVICE_ID,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
.start();
}
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey());
}
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
false,
null,
rootPublicKey,
null,
(byte) 0,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
.start();
}
}

View File

@@ -0,0 +1,104 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {
private EmbeddedChannel embeddedChannel;
private static final String AUTHENTICATED_PATH = "/authenticated";
private static final String ANONYMOUS_PATH = "/anonymous";
@BeforeEach
void setUp() {
embeddedChannel = new EmbeddedChannel(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_PATH, ANONYMOUS_PATH));
}
@ParameterizedTest
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
void handleValidRequest(final String path) {
final FullHttpRequest request = buildRequest(HttpMethod.GET, path,
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
try {
embeddedChannel.writeOneInbound(request);
assertEquals(1, request.refCnt());
assertEquals(1, embeddedChannel.inboundMessages().size());
assertEquals(request, embeddedChannel.inboundMessages().poll());
} finally {
request.release();
}
}
@ParameterizedTest
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
void handleUpgradeRequired(final String path) {
final FullHttpRequest request = buildRequest(HttpMethod.GET, path, new DefaultHttpHeaders());
embeddedChannel.writeOneInbound(request);
assertEquals(0, request.refCnt());
assertHttpResponse(HttpResponseStatus.UPGRADE_REQUIRED);
}
@Test
void handleBadPath() {
final FullHttpRequest request = buildRequest(HttpMethod.GET, "/incorrect",
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
embeddedChannel.writeOneInbound(request);
assertEquals(0, request.refCnt());
assertHttpResponse(HttpResponseStatus.NOT_FOUND);
}
@ParameterizedTest
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
void handleMethodNotAllowed(final String path) {
final FullHttpRequest request = buildRequest(HttpMethod.DELETE, path,
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
embeddedChannel.writeOneInbound(request);
assertEquals(0, request.refCnt());
assertHttpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED);
}
private void assertHttpResponse(final HttpResponseStatus expectedStatus) {
assertEquals(1, embeddedChannel.outboundMessages().size());
final FullHttpResponse response = assertInstanceOf(FullHttpResponse.class, embeddedChannel.outboundMessages().poll());
//noinspection DataFlowIssue
assertEquals(expectedStatus, response.status());
}
private FullHttpRequest buildRequest(final HttpMethod method, final String path, final HttpHeaders headers) {
return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1,
method,
path,
Unpooled.buffer(0),
headers,
new DefaultHttpHeaders(true));
}
}

View File

@@ -0,0 +1,91 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest {
private UserEventRecordingHandler userEventRecordingHandler;
private EmbeddedChannel embeddedChannel;
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
private final List<Object> receivedEvents = new ArrayList<>();
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
receivedEvents.add(event);
}
public List<Object> getReceivedEvents() {
return receivedEvents;
}
}
@BeforeEach
void setUp() {
userEventRecordingHandler = new UserEventRecordingHandler();
embeddedChannel = new EmbeddedChannel(
new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]),
userEventRecordingHandler);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
}
private static List<Arguments> handleWebSocketHandshakeComplete() {
return List.of(
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
}
@Test
void handleWebSocketHandshakeCompleteUnexpectedPath() {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
}
@Test
void handleUnrecognizedEvent() {
final Object unrecognizedEvent = new Object();
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
}
}