Add a plaintext mode to the Noise-over-WebSocket server for local testing

This commit is contained in:
Jon Chambers
2024-05-21 17:23:22 -04:00
committed by Jon Chambers
parent 9e36cabef0
commit 9a2bfe1180
4 changed files with 106 additions and 34 deletions

View File

@@ -28,7 +28,8 @@ import org.signal.libsignal.protocol.ecc.ECPublicKey;
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final X509Certificate trustedServerCertificate;
private final boolean useTls;
@Nullable private final X509Certificate trustedServerCertificate;
private final URI websocketUri;
private final boolean authenticated;
@Nullable private final ECKeyPair ecKeyPair;
@@ -44,7 +45,8 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
EstablishRemoteConnectionHandler(
final X509Certificate trustedServerCertificate,
final boolean useTls,
@Nullable final X509Certificate trustedServerCertificate,
final URI websocketUri,
final boolean authenticated,
@Nullable final ECKeyPair ecKeyPair,
@@ -55,6 +57,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener) {
this.useTls = useTls;
this.trustedServerCertificate = trustedServerCertificate;
this.websocketUri = websocketUri;
this.authenticated = authenticated;
@@ -75,12 +78,17 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel channel) throws SSLException {
if (useTls) {
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (trustedServerCertificate != null) {
sslContextBuilder.trustManager(trustedServerCertificate);
}
channel.pipeline().addLast(sslContextBuilder.build().newHandler(channel.alloc()));
}
channel.pipeline()
.addLast(SslContextBuilder
.forClient()
.trustManager(trustedServerCertificate)
.build()
.newHandler(channel.alloc()))
.addLast(new HttpClientCodec())
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we

View File

@@ -32,7 +32,8 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
@Nullable final UUID accountIdentifier,
final byte deviceId,
final HttpHeaders headers,
final X509Certificate trustedServerCertificate,
final boolean useTls,
@Nullable final X509Certificate trustedServerCertificate,
final NioEventLoopGroup eventLoopGroup,
final WebSocketCloseListener webSocketCloseListener) {
@@ -43,7 +44,8 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(trustedServerCertificate,
localChannel.pipeline().addLast(new EstablishRemoteConnectionHandler(useTls,
trustedServerCertificate,
websocketUri,
authenticated,
ecKeyPair,

View File

@@ -89,7 +89,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private ManagedLocalGrpcServer authenticatedGrpcServer;
private ManagedLocalGrpcServer anonymousGrpcServer;
private WebsocketNoiseTunnelServer websocketNoiseTunnelServer;
private WebsocketNoiseTunnelServer tlsWebsocketNoiseTunnelServer;
private WebsocketNoiseTunnelServer plaintextWebsocketNoiseTunnelServer;
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
@@ -184,7 +185,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer.start();
websocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
tlsWebsocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
new X509Certificate[] { serverTlsCertificate },
serverTlsPrivateKey,
nioEventLoopGroup,
@@ -197,12 +198,28 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
websocketNoiseTunnelServer.start();
tlsWebsocketNoiseTunnelServer.start();
plaintextWebsocketNoiseTunnelServer = new WebsocketNoiseTunnelServer(0,
null,
null,
nioEventLoopGroup,
delegatedTaskExecutor,
clientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
plaintextWebsocketNoiseTunnelServer.start();
}
@AfterEach
void tearDown() throws InterruptedException {
websocketNoiseTunnelServer.stop();
tlsWebsocketNoiseTunnelServer.stop();
plaintextWebsocketNoiseTunnelServer.stop();
authenticatedGrpcServer.stop();
anonymousGrpcServer.stop();
}
@@ -234,6 +251,36 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
@Test
void connectAuthenticatedPlaintext() throws InterruptedException {
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(tlsWebsocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
true,
clientKeyPair,
rootKeyPair.getPublicKey(),
ACCOUNT_IDENTIFIER,
DEVICE_ID,
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
nioEventLoopGroup,
WebSocketCloseListener.NOOP_LISTENER)
.start()) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAuthenticatedBadServerKeySignature() throws InterruptedException {
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
@@ -313,7 +360,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
websocketNoiseTunnelServer.getLocalAddress(),
tlsWebsocketNoiseTunnelServer.getLocalAddress(),
URI.create("wss://localhost/anonymous"),
true,
clientKeyPair,
@@ -321,6 +368,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
ACCOUNT_IDENTIFIER,
DEVICE_ID,
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
@@ -386,7 +434,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final WebSocketCloseListener webSocketCloseListener = mock(WebSocketCloseListener.class);
try (final WebSocketNoiseTunnelClient websocketNoiseTunnelClient = new WebSocketNoiseTunnelClient(
websocketNoiseTunnelServer.getLocalAddress(),
tlsWebsocketNoiseTunnelServer.getLocalAddress(),
URI.create("wss://localhost/authenticated"),
false,
null,
@@ -394,6 +442,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
null,
(byte) 0,
new DefaultHttpHeaders(),
true,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
@@ -438,10 +487,10 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
final URI authenticatedUri =
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
new URI("https", null, "localhost", tlsWebsocketNoiseTunnelServer.getLocalAddress().getPort(), "/authenticated", null, null);
final URI incorrectUri =
new URI("https", null, "localhost", websocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
new URI("https", null, "localhost", tlsWebsocketNoiseTunnelServer.getLocalAddress().getPort(), "/incorrect", null, null);
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
@@ -561,7 +610,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
return new WebSocketNoiseTunnelClient(tlsWebsocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
true,
clientKeyPair,
@@ -569,6 +618,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
ACCOUNT_IDENTIFIER,
DEVICE_ID,
headers,
true,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
@@ -583,7 +633,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
return new WebSocketNoiseTunnelClient(tlsWebsocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
false,
null,
@@ -591,6 +641,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
null,
(byte) 0,
headers,
true,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)