() {
- @Override
- protected void initChannel(final LocalChannel localChannel) {
- localChannel.pipeline().addLast(new ProxyHandler(remoteChannelContext.channel()));
- }
- })
- .connect()
- .addListener((ChannelFutureListener) localChannelFuture -> {
- if (localChannelFuture.isSuccess()) {
- grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
- remoteChannelContext.channel(),
- authenticatedDevice);
-
- // Close the local connection if the remote channel closes and vice versa
- remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
- localChannelFuture.channel().closeFuture().addListener(closeFuture ->
- remoteChannelContext.channel()
- .write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
- .addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
-
- remoteChannelContext.pipeline()
- .addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
-
- // Flush any buffered reads we accumulated while waiting to open the connection
- pendingReads.forEach(remoteChannelContext::fireChannelRead);
- pendingReads.clear();
-
- remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
- } else {
- log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause());
- remoteChannelContext.close();
- }
- });
- }
-
- remoteChannelContext.fireUserEventTriggered(event);
- }
-
- @Override
- public void handlerRemoved(final ChannelHandlerContext context) {
- pendingReads.forEach(ReferenceCountUtil::release);
- pendingReads.clear();
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/FramingType.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/FramingType.java
deleted file mode 100644
index 0343b8e9a..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/FramingType.java
+++ /dev/null
@@ -1,6 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-public enum FramingType {
- NOISE_DIRECT,
- WEBSOCKET
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java
deleted file mode 100644
index ad12da9fe..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java
+++ /dev/null
@@ -1,268 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import com.google.common.annotations.VisibleForTesting;
-import io.grpc.Grpc;
-import io.grpc.ServerCall;
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.local.LocalAddress;
-import io.netty.channel.local.LocalChannel;
-import io.netty.util.AttributeKey;
-import java.net.InetAddress;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Optional;
-import java.util.concurrent.ConcurrentHashMap;
-import javax.annotation.Nullable;
-import org.apache.commons.lang3.StringUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
-import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
-import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
-import org.whispersystems.textsecuregcm.util.ClosableEpoch;
-
-/**
- * A client connection manager associates a local connection to a local gRPC server with a remote connection through a
- * Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
- * identity of the device that opened the connection (for non-anonymous connections). It can also close connections
- * associated with a given device if that device's credentials have changed and clients must reauthenticate.
- *
- * In general, all {@link ServerCall}s must have a local address that in turn should be resolvable to
- * a remote channel, which must have associated request attributes and authentication status. It is possible
- * that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
- * narrow window between a server call being created and the start of call execution, in which case accessor methods
- * in this class will throw a {@link ChannelNotFoundException}.
- *
- * A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
- * identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
- * Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
- * be called from any application code.
- */
-public class GrpcClientConnectionManager {
-
- private final Map remoteChannelsByLocalAddress = new ConcurrentHashMap<>();
- private final Map> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>();
-
- @VisibleForTesting
- static final AttributeKey AUTHENTICATED_DEVICE_ATTRIBUTE_KEY =
- AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
-
- @VisibleForTesting
- public static final AttributeKey REQUEST_ATTRIBUTES_KEY =
- AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
-
- @VisibleForTesting
- static final AttributeKey EPOCH_ATTRIBUTE_KEY =
- AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
-
- private static final OutboundCloseErrorMessage SERVER_CLOSED =
- new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
-
- private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
-
- /**
- * Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
- * (i.e. unauthenticated), the returned value will be empty.
- *
- * @param serverCall the gRPC server call for which to find an authenticated device
- *
- * @return the authenticated device associated with the given local address, if any
- *
- * @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
- * generally indicates that the channel has closed while request processing is still in progress
- */
- public Optional getAuthenticatedDevice(final ServerCall, ?> serverCall)
- throws ChannelNotFoundException {
-
- return getAuthenticatedDevice(getRemoteChannel(serverCall));
- }
-
- @VisibleForTesting
- Optional getAuthenticatedDevice(final Channel remoteChannel) {
- return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
- }
-
- /**
- * Returns the request attributes associated with the given server call.
- *
- * @param serverCall the gRPC server call for which to retrieve request attributes
- *
- * @return the request attributes associated with the given server call
- *
- * @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
- * generally indicates that the channel has closed while request processing is still in progress
- */
- public RequestAttributes getRequestAttributes(final ServerCall, ?> serverCall) throws ChannelNotFoundException {
- return getRequestAttributes(getRemoteChannel(serverCall));
- }
-
- @VisibleForTesting
- RequestAttributes getRequestAttributes(final Channel remoteChannel) {
- final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
-
- if (requestAttributes == null) {
- throw new IllegalStateException("Channel does not have request attributes");
- }
-
- return requestAttributes;
- }
-
- /**
- * Handles the start of a server call, incrementing the active call count for the remote channel associated with the
- * given server call.
- *
- * @param serverCall the server call to start
- *
- * @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
- * underlying channel is closing
- */
- public boolean handleServerCallStart(final ServerCall, ?> serverCall) {
- try {
- return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
- } catch (final ChannelNotFoundException e) {
- // This would only happen if the channel had already closed, which is certainly possible. In this case, the call
- // should certainly not proceed.
- return false;
- }
- }
-
- /**
- * Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
- * associated with the given server call.
- *
- * @param serverCall the server call to complete
- */
- public void handleServerCallComplete(final ServerCall, ?> serverCall) {
- try {
- getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
- } catch (final ChannelNotFoundException ignored) {
- // In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
- }
- }
-
- /**
- * Closes any client connections to this host associated with the given authenticated device.
- *
- * @param authenticatedDevice the authenticated device for which to close connections
- */
- public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
- // Channels will actually get removed from the list/map by their closeFuture listeners. We copy the list to avoid
- // concurrent modification; it's possible (though practically unlikely) that a channel can close and remove itself
- // from the list while we're still iterating, resulting in a `ConcurrentModificationException`.
- final List channelsToClose =
- new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
-
- channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
- }
-
- private static void closeRemoteChannel(final Channel channel) {
- channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
- }
-
- @VisibleForTesting
- @Nullable List getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) {
- return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
- }
-
- private Channel getRemoteChannel(final ServerCall, ?> serverCall) throws ChannelNotFoundException {
- return getRemoteChannel(getLocalAddress(serverCall));
- }
-
- @VisibleForTesting
- Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
- final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
-
- if (remoteChannel == null) {
- throw new ChannelNotFoundException();
- }
-
- return remoteChannelsByLocalAddress.get(localAddress);
- }
-
- private static LocalAddress getLocalAddress(final ServerCall, ?> serverCall) {
- // In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
- // The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
- // a proxied Noise channel.
- if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
- throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
- }
-
- return localAddress;
- }
-
- /**
- * Handles receipt of a handshake message and associates attributes and headers from the handshake
- * request with the channel via which the handshake took place.
- *
- * @param channel the channel where the handshake was initiated
- * @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
- * @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
- * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
- * {@code null}
- */
- public static void handleHandshakeInitiated(final Channel channel,
- final InetAddress preferredRemoteAddress,
- @Nullable final String userAgentHeader,
- @Nullable final String acceptLanguageHeader) {
-
- @Nullable List acceptLanguages = Collections.emptyList();
-
- if (StringUtils.isNotBlank(acceptLanguageHeader)) {
- try {
- acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
- } catch (final IllegalArgumentException e) {
- log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
- }
- }
-
- channel.attr(REQUEST_ATTRIBUTES_KEY)
- .set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
- }
-
- /**
- * Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
- *
- * @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
- * @param remoteChannel the channel from the remote client to the Noise tunnel
- * @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
- */
- void handleConnectionEstablished(final LocalChannel localChannel,
- final Channel remoteChannel,
- @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional maybeAuthenticatedDevice) {
-
- maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
- remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
-
- remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
- .set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
-
- remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
-
- getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
- remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
- final List channels = existingChannelList != null ? existingChannelList : new ArrayList<>();
- channels.add(remoteChannel);
-
- return channels;
- }));
-
- remoteChannel.closeFuture().addListener(closeFuture -> {
- remoteChannelsByLocalAddress.remove(localChannel.localAddress());
-
- getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
- remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
- if (existingChannelList == null) {
- return null;
- }
-
- existingChannelList.remove(remoteChannel);
-
- return existingChannelList.isEmpty() ? null : existingChannelList;
- }));
- });
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java
deleted file mode 100644
index 2bd807967..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java
+++ /dev/null
@@ -1,32 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.handler.codec.haproxy.HAProxyMessage;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * An HAProxy message handler handles decoded HAProxyMessage instances, removing itself from the pipeline once it has
- * either handled a proxy protocol message or determined that no such message is coming.
- */
-public class HAProxyMessageHandler extends ChannelInboundHandlerAdapter {
-
- private static final Logger log = LoggerFactory.getLogger(HAProxyMessageHandler.class);
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
- if (message instanceof HAProxyMessage haProxyMessage) {
- // Some network/deployment configurations will send us a proxy protocol message, but we don't use it. We still
- // need to clear it from the pipeline to avoid confusing the TLS machinery, though.
- log.debug("Discarding HAProxy message: {}", haProxyMessage);
- haProxyMessage.release();
- } else {
- super.channelRead(context, message);
- }
-
- // Regardless of the type of the first message, we'll only ever receive zero or one HAProxyMessages. After the first
- // message, all others will just be "normal" messages, and our work here is done.
- context.pipeline().remove(this);
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java
deleted file mode 100644
index cfd0205fc..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * Copyright 2024 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net;
-
-public enum HandshakePattern {
- NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
- IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
-
- private final String protocol;
-
- public String protocol() {
- return protocol;
- }
-
-
- HandshakePattern(String protocol) {
- this.protocol = protocol;
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java
deleted file mode 100644
index 5888174c4..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java
+++ /dev/null
@@ -1,16 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import io.dropwizard.lifecycle.Managed;
-import io.netty.channel.DefaultEventLoopGroup;
-
-/**
- * A wrapper for a Netty {@link DefaultEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
- * Dropwizard to manage the lifecycle of the event loop group.
- */
-public class ManagedDefaultEventLoopGroup extends DefaultEventLoopGroup implements Managed {
-
- @Override
- public void stop() throws InterruptedException {
- this.shutdownGracefully().await();
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedGrpcServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedGrpcServer.java
new file mode 100644
index 000000000..87ac886da
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedGrpcServer.java
@@ -0,0 +1,29 @@
+package org.whispersystems.textsecuregcm.grpc.net;
+
+import io.dropwizard.lifecycle.Managed;
+import io.grpc.Server;
+
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+public class ManagedGrpcServer implements Managed {
+ private final Server server;
+
+ public ManagedGrpcServer(Server server) {
+ this.server = server;
+ }
+
+ @Override
+ public void start() throws IOException {
+ server.start();
+ }
+
+ @Override
+ public void stop() {
+ try {
+ server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
+ } catch (final InterruptedException e) {
+ server.shutdownNow();
+ }
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java
deleted file mode 100644
index 21cd3d846..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java
+++ /dev/null
@@ -1,49 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import io.dropwizard.lifecycle.Managed;
-import io.grpc.Server;
-import io.grpc.ServerBuilder;
-import io.grpc.netty.NettyServerBuilder;
-import io.netty.channel.DefaultEventLoopGroup;
-import io.netty.channel.local.LocalAddress;
-import io.netty.channel.local.LocalServerChannel;
-import java.io.IOException;
-import java.util.concurrent.TimeUnit;
-
-/**
- * A managed, local gRPC server configures and wraps a gRPC {@link Server} that listens on a Netty {@link LocalAddress}
- * and whose lifecycle is managed by Dropwizard via the {@link Managed} interface.
- */
-public abstract class ManagedLocalGrpcServer implements Managed {
-
- private final Server server;
-
- public ManagedLocalGrpcServer(final LocalAddress localAddress,
- final DefaultEventLoopGroup eventLoopGroup) {
-
- final ServerBuilder> serverBuilder = NettyServerBuilder.forAddress(localAddress)
- .channelType(LocalServerChannel.class)
- .bossEventLoopGroup(eventLoopGroup)
- .workerEventLoopGroup(eventLoopGroup);
-
- configureServer(serverBuilder);
-
- server = serverBuilder.build();
- }
-
- protected abstract void configureServer(final ServerBuilder> serverBuilder);
-
- @Override
- public void start() throws IOException {
- server.start();
- }
-
- @Override
- public void stop() {
- try {
- server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
- } catch (final InterruptedException e) {
- server.shutdownNow();
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java
deleted file mode 100644
index 55bc979c7..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java
+++ /dev/null
@@ -1,13 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import org.whispersystems.textsecuregcm.util.NoStackTraceException;
-
-/**
- * Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
- * format or a general encryption error).
- */
-public class NoiseException extends NoStackTraceException {
- public NoiseException(final String message) {
- super(message);
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java
deleted file mode 100644
index 42e871e35..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Copyright 2024 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import com.southernstorm.noise.protocol.CipherState;
-import com.southernstorm.noise.protocol.CipherStatePair;
-import com.southernstorm.noise.protocol.Noise;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufUtil;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.util.ReferenceCountUtil;
-import io.netty.util.concurrent.PromiseCombiner;
-import javax.crypto.BadPaddingException;
-import javax.crypto.ShortBufferException;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound
- * messages
- */
-public class NoiseHandler extends ChannelDuplexHandler {
-
- private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
- private final CipherStatePair cipherStatePair;
-
- NoiseHandler(CipherStatePair cipherStatePair) {
- this.cipherStatePair = cipherStatePair;
- }
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
- try {
- if (message instanceof ByteBuf frame) {
- if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
- throw new NoiseException("Invalid noise message length " + frame.readableBytes());
- }
- // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
- // We'll need to copy it to a heap buffer.
- handleInboundDataMessage(context, ByteBufUtil.getBytes(frame));
- } else {
- // Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
- throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
- }
- } finally {
- ReferenceCountUtil.release(message);
- }
- }
-
-
- private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes)
- throws ShortBufferException, BadPaddingException {
- final CipherState cipherState = cipherStatePair.getReceiver();
- // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
- final int plaintextLength = cipherState.decryptWithAd(null,
- frameBytes, 0,
- frameBytes, 0,
- frameBytes.length);
-
- // Forward the decrypted plaintext along
- context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
- }
-
- @Override
- public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
- throws Exception {
- if (message instanceof ByteBuf byteBuf) {
- try {
- // TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
- final CipherState cipherState = cipherStatePair.getSender();
-
- // Server message might not fit in a single noise packet, break it up into as many chunks as we need
- final PromiseCombiner pc = new PromiseCombiner(context.executor());
- while (byteBuf.isReadable()) {
- final ByteBuf plaintext = byteBuf.readSlice(Math.min(
- // need room for a 16-byte AEAD tag
- Noise.MAX_PACKET_LEN - 16,
- byteBuf.readableBytes()));
-
- final int plaintextLength = plaintext.readableBytes();
-
- // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
- // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
- // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
- final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
- plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
-
- // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
- cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
-
- pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
- }
- pc.finish(promise);
- } finally {
- ReferenceCountUtil.release(byteBuf);
- }
- } else {
- if (!(message instanceof OutboundCloseErrorMessage)) {
- // Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
- // that get issued in response to exceptions)
- log.warn("Unexpected object in pipeline: {}", message);
- }
- context.write(message, promise);
- }
- }
-
- @Override
- public void handlerRemoved(ChannelHandlerContext var1) {
- if (cipherStatePair != null) {
- cipherStatePair.destroy();
- }
- }
-
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java
deleted file mode 100644
index b7a509728..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java
+++ /dev/null
@@ -1,14 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import org.whispersystems.textsecuregcm.util.NoStackTraceException;
-
-/**
- * Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
- * a general encryption error).
- */
-public class NoiseHandshakeException extends NoStackTraceException {
-
- public NoiseHandshakeException(final String message) {
- super(message);
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java
deleted file mode 100644
index 80d7f6105..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Copyright 2024 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import com.southernstorm.noise.protocol.Noise;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufInputStream;
-import io.netty.buffer.ByteBufUtil;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.util.ReferenceCountUtil;
-import java.io.IOException;
-import java.net.InetAddress;
-import java.security.MessageDigest;
-import java.util.Optional;
-import java.util.UUID;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
-import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil;
-import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
-import org.whispersystems.textsecuregcm.util.ExceptionUtils;
-import org.whispersystems.textsecuregcm.util.UUIDUtil;
-
-/**
- * Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will
- * encrypt/decrypt subsequent data frames
- *
- * The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator
- * handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler
- * currently supports two types of handshakes.
- *
- * The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator
- * handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in
- * the handshake message must match the server's stored public key for the device identified by the provided ACI and
- * deviceId.
- *
- * The second are NK handshakes which are anonymous.
- *
- * Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server
- * to begin processing the request without an initial message delay (fast open).
- *
- * Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake,
- * this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This
- * handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data
- * frames.
- */
-public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
-
- private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
- .setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY)
- .build().toByteArray();
- private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
- .setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
- .build().toByteArray();
-
- // We might get additional messages while we're waiting to process a handshake, so keep track of where we are
- private boolean receivedHandshakeInit = false;
-
- private final ClientPublicKeysManager clientPublicKeysManager;
- private final ECKeyPair ecKeyPair;
-
- public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
- this.clientPublicKeysManager = clientPublicKeysManager;
- this.ecKeyPair = ecKeyPair;
- }
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
- try {
- if (!(message instanceof NoiseHandshakeInit handshakeInit)) {
- // Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error
- throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
- }
- if (receivedHandshakeInit) {
- throw new NoiseHandshakeException("Should not receive messages until handshake complete");
- }
- receivedHandshakeInit = true;
-
- if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) {
- throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes());
- }
-
- // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
- // We'll need to copy it to a heap buffer
- handleInboundHandshake(context,
- handshakeInit.getRemoteAddress(),
- handshakeInit.getHandshakePattern(),
- ByteBufUtil.getBytes(handshakeInit.content()));
- } finally {
- ReferenceCountUtil.release(message);
- }
- }
-
- private void handleInboundHandshake(
- final ChannelHandlerContext context,
- final InetAddress remoteAddress,
- final HandshakePattern handshakePattern,
- final byte[] frameBytes) throws NoiseHandshakeException {
- final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair);
- final ByteBuf payload = handshakeHelper.read(frameBytes);
-
- // Parse the handshake message
- final NoiseTunnelProtos.HandshakeInit handshakeInit;
- try {
- handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload));
- } catch (IOException e) {
- throw new NoiseHandshakeException("Failed to parse handshake message");
- }
-
- switch (handshakePattern) {
- case NK -> {
- if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) {
- throw new NoiseHandshakeException("Anonymous handshake should not include identifiers");
- }
- handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty());
- }
- case IK -> {
- final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey()
- .orElseThrow(() -> new IllegalStateException("No remote public key"));
- final UUID accountIdentifier = aci(handshakeInit);
- final byte deviceId = deviceId(handshakeInit);
- clientPublicKeysManager
- .findPublicKey(accountIdentifier, deviceId)
- .whenCompleteAsync((storedPublicKey, throwable) -> {
- if (throwable != null) {
- context.fireExceptionCaught(ExceptionUtils.unwrap(throwable));
- return;
- }
- final boolean valid = storedPublicKey
- .map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
- .orElse(false);
- if (!valid) {
- // Write a handshake response indicating that the client used the wrong public key
- final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK);
- context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
- .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
-
- context.fireExceptionCaught(new NoiseHandshakeException("Bad public key"));
- return;
- }
- handleAuthenticated(context,
- handshakeHelper, remoteAddress, handshakeInit,
- Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
- }, context.executor());
- }
- };
- }
-
- private void handleAuthenticated(final ChannelHandlerContext context,
- final NoiseHandshakeHelper handshakeHelper,
- final InetAddress remoteAddress,
- final NoiseTunnelProtos.HandshakeInit handshakeInit,
- final Optional maybeAuthenticatedDevice) {
- context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(
- maybeAuthenticatedDevice,
- remoteAddress,
- handshakeInit.getUserAgent(),
- handshakeInit.getAcceptLanguage()));
-
- // Now that we've authenticated, write the handshake response
- final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK);
- context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
- .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
-
- // The handshake is complete. We can start intercepting read/write for noise encryption/decryption
- // Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open
- // request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already
- // decrypted the fastOpen request
- context.pipeline()
- .addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split()));
- context.pipeline().remove(NoiseHandshakeHandler.class);
- if (!handshakeInit.getFastOpenRequest().isEmpty()) {
- // The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
- // encrypt the response when the server writes back through us
- context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer()));
- }
- }
-
- private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
- try {
- return UUIDUtil.fromByteString(handshakePayload.getAci());
- } catch (IllegalArgumentException e) {
- throw new NoiseHandshakeException("Could not parse aci");
- }
- }
-
- private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
- if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) {
- throw new NoiseHandshakeException("Invalid deviceId");
- }
- return (byte) handshakePayload.getDeviceId();
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java
deleted file mode 100644
index e47120fb4..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Copyright 2024 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import com.southernstorm.noise.protocol.HandshakeState;
-import com.southernstorm.noise.protocol.Noise;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import java.security.NoSuchAlgorithmException;
-import java.util.Optional;
-import javax.crypto.BadPaddingException;
-import javax.crypto.ShortBufferException;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-
-/**
- * Helper for the responder of a 2-message handshake with a pre-shared responder static key
- */
-class NoiseHandshakeHelper {
-
- private final static int AEAD_TAG_LENGTH = 16;
- private final static int KEY_LENGTH = 32;
-
- private final HandshakePattern handshakePattern;
- private final ECKeyPair serverStaticKeyPair;
- private final HandshakeState handshakeState;
-
- NoiseHandshakeHelper(HandshakePattern handshakePattern, ECKeyPair serverStaticKeyPair) {
- this.handshakePattern = handshakePattern;
- this.serverStaticKeyPair = serverStaticKeyPair;
- try {
- this.handshakeState = new HandshakeState(handshakePattern.protocol(), HandshakeState.RESPONDER);
- } catch (final NoSuchAlgorithmException e) {
- throw new AssertionError("Unsupported Noise algorithm: " + handshakePattern.protocol(), e);
- }
- }
-
- /**
- * Get the length of the initiator's keys
- *
- * @return length of the handshake message sent by the remote party (the initiator) not including the payload
- */
- private int initiatorHandshakeMessageKeyLength() {
- return switch (handshakePattern) {
- // ephemeral key, static key (encrypted), AEAD tag for static key
- case IK -> KEY_LENGTH + KEY_LENGTH + AEAD_TAG_LENGTH;
- // ephemeral key only
- case NK -> KEY_LENGTH;
- };
- }
-
- HandshakeState getHandshakeState() {
- return this.handshakeState;
- }
-
- ByteBuf read(byte[] remoteHandshakeMessage) throws NoiseHandshakeException {
- if (handshakeState.getAction() != HandshakeState.NO_ACTION) {
- throw new NoiseHandshakeException("Cannot send more data before handshake is complete");
- }
-
- // Length for an empty payload
- final int minMessageLength = initiatorHandshakeMessageKeyLength() + AEAD_TAG_LENGTH;
- if (remoteHandshakeMessage.length < minMessageLength || remoteHandshakeMessage.length > Noise.MAX_PACKET_LEN) {
- throw new NoiseHandshakeException("Unexpected ephemeral key message length");
- }
-
- final int payloadLength = remoteHandshakeMessage.length - initiatorHandshakeMessageKeyLength() - AEAD_TAG_LENGTH;
-
- // Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
- // making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
- // from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
- // key), so we just set the private key here.
- handshakeState.getLocalKeyPair().setPrivateKey(serverStaticKeyPair.getPrivateKey().serialize(), 0);
- handshakeState.start();
-
- int payloadBytesRead;
-
- try {
- payloadBytesRead = handshakeState.readMessage(remoteHandshakeMessage, 0, remoteHandshakeMessage.length,
- remoteHandshakeMessage, 0);
- } catch (final ShortBufferException e) {
- // This should never happen since we're checking the length of the frame up front
- throw new NoiseHandshakeException("Unexpected client payload");
- } catch (final BadPaddingException e) {
- // We aren't using padding but may get this error if the AEAD tag does not match the encrypted client static key
- // or payload
- throw new NoiseHandshakeException("Invalid keys or payload");
- }
- if (payloadBytesRead != payloadLength) {
- throw new NoiseHandshakeException(
- "Unexpected payload length, required " + payloadLength + " but got " + payloadBytesRead);
- }
- return Unpooled.wrappedBuffer(remoteHandshakeMessage, 0, payloadBytesRead);
- }
-
- byte[] write(byte[] payload) {
- if (handshakeState.getAction() != HandshakeState.WRITE_MESSAGE) {
- throw new IllegalStateException("Cannot send data before handshake is complete");
- }
-
- // Currently only support handshake patterns where the server static key is known
- // Send our ephemeral key and the response to the initiator with the encrypted payload
- final byte[] response = new byte[KEY_LENGTH + payload.length + AEAD_TAG_LENGTH];
- try {
- int written = handshakeState.writeMessage(response, 0, payload, 0, payload.length);
- if (written != response.length) {
- throw new IllegalStateException("Unexpected handshake response length");
- }
- return response;
- } catch (final ShortBufferException e) {
- // This should never happen for messages of known length that we control
- throw new IllegalStateException("Key material buffer was too short for message", e);
- }
- }
-
- Optional remotePublicKey() {
- return Optional.ofNullable(handshakeState.getRemotePublicKey()).map(dhstate -> {
- final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
- handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
- return publicKeyFromClient;
- });
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java
deleted file mode 100644
index d5368f366..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java
+++ /dev/null
@@ -1,33 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.DefaultByteBufHolder;
-import java.net.InetAddress;
-
-/**
- * A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata
- * and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic.
- */
-public class NoiseHandshakeInit extends DefaultByteBufHolder {
-
- private final InetAddress remoteAddress;
- private final HandshakePattern handshakePattern;
-
- public NoiseHandshakeInit(
- final InetAddress remoteAddress,
- final HandshakePattern handshakePattern,
- final ByteBuf initiatorHandshakeMessage) {
- super(initiatorHandshakeMessage);
- this.remoteAddress = remoteAddress;
- this.handshakePattern = handshakePattern;
- }
-
- public InetAddress getRemoteAddress() {
- return remoteAddress;
- }
-
- public HandshakePattern getHandshakePattern() {
- return handshakePattern;
- }
-
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java
deleted file mode 100644
index 45344248d..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java
+++ /dev/null
@@ -1,22 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import java.net.InetAddress;
-import java.util.Optional;
-import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
-
-/**
- * An event that indicates that an identity of a noise handshake initiator has been determined. If the initiator is
- * connecting anonymously, the identity is empty, otherwise it will be present and already authenticated.
- *
- * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
- * type that performs authentication
- * @param remoteAddress the remote address of the connecting client
- * @param userAgent the client supplied userAgent
- * @param acceptLanguage the client supplied acceptLanguage
- */
-public record NoiseIdentityDeterminedEvent(
- Optional authenticatedDevice,
- InetAddress remoteAddress,
- String userAgent,
- String acceptLanguage) {
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java
deleted file mode 100644
index 3ef4f84f1..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net;
-
-/**
- * An error written to the outbound pipeline that indicates the connection should be closed
- */
-public record OutboundCloseErrorMessage(Code code, String message) {
- public enum Code {
-
- /**
- * The server decided to close the connection. This could be because the server is going away, or it could be
- * because the credentials for the connected client have been updated.
- */
- SERVER_CLOSED,
-
- /**
- * There was a noise decryption error after the noise session was established
- */
- NOISE_ERROR,
-
- /**
- * There was an error establishing the noise handshake
- */
- NOISE_HANDSHAKE_ERROR,
-
- INTERNAL_SERVER_ERROR
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java
deleted file mode 100644
index 00afa516d..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-
-/**
- * A proxy handler writes all data read from one channel to another peer channel.
- */
-public class ProxyHandler extends ChannelInboundHandlerAdapter {
-
- private final Channel peerChannel;
-
- public ProxyHandler(final Channel peerChannel) {
- this.peerChannel = peerChannel;
- }
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object message) {
- peerChannel.writeAndFlush(message)
- .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java
deleted file mode 100644
index 3b8951658..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java
+++ /dev/null
@@ -1,73 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import com.google.common.annotations.VisibleForTesting;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.CompositeByteBuf;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
-
-/**
- * A proxy protocol detection handler watches for HAProxy PROXY protocol messages at the beginning of a TCP connection.
- * If a connection begins with a proxy message, this handler will add a {@link HAProxyMessageDecoder} to the pipeline.
- * In all cases, once this handler has determined that a connection does or does not begin with a proxy protocol
- * message, it will remove itself from the pipeline and pass any intercepted down the pipeline.
- *
- * @see The PROXY protocol
- */
-public class ProxyProtocolDetectionHandler extends ChannelInboundHandlerAdapter {
-
- private CompositeByteBuf accumulator;
-
- @VisibleForTesting
- static final int PROXY_MESSAGE_DETECTION_BYTES = 12;
-
- @Override
- public void handlerAdded(final ChannelHandlerContext context) {
- // We need at least 12 bytes to decide if a byte buffer contains a proxy protocol message. Assuming we only get
- // non-empty buffers, that means we'll need at most 12 sub-buffers to have a complete message. In virtually every
- // practical case, though, we'll be able to tell from the first packet.
- accumulator = new CompositeByteBuf(context.alloc(), false, PROXY_MESSAGE_DETECTION_BYTES);
- }
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
- if (message instanceof ByteBuf byteBuf) {
- accumulator.addComponent(true, byteBuf);
-
- switch (HAProxyMessageDecoder.detectProtocol(accumulator).state()) {
- case NEEDS_MORE_DATA -> {
- }
-
- case INVALID -> {
- // We have enough information to determine that this connection is NOT starting with a proxy protocol message,
- // and we can just pass the accumulated bytes through
- context.fireChannelRead(accumulator);
-
- accumulator = null;
- context.pipeline().remove(this);
- }
-
- case DETECTED -> {
- // We have enough information to know that we're dealing with a proxy protocol message; add appropriate
- // handlers and pass the accumulated bytes through
- context.pipeline().addAfter(context.name(), null, new HAProxyMessageDecoder());
- context.fireChannelRead(accumulator);
-
- accumulator = null;
- context.pipeline().remove(this);
- }
- }
- } else {
- super.channelRead(context, message);
- }
- }
-
- @Override
- public void handlerRemoved(final ChannelHandlerContext context) {
- if (accumulator != null) {
- accumulator.release();
- accumulator = null;
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java
deleted file mode 100644
index 0ed2f938b..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.util.ReferenceCountUtil;
-import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
-
-/**
- * In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
- * noise packet to the noise layer as a {@link ByteBuf} for decryption.
- *
- * In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
- * so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
- * handshake response, and then the subsequent messages are all data frame payloads.
- */
-public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
-
- @Override
- public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
- if (msg instanceof NoiseDirectFrame frame) {
- if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
- ReferenceCountUtil.release(msg);
- throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
- }
- ctx.fireChannelRead(frame.content());
- } else {
- ctx.fireChannelRead(msg);
- }
- }
-
- @Override
- public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
- if (msg instanceof ByteBuf bb) {
- ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
- } else {
- ctx.write(msg, promise);
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java
deleted file mode 100644
index 3545d5797..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.DefaultByteBufHolder;
-
-public class NoiseDirectFrame extends DefaultByteBufHolder {
-
- static final byte VERSION = 0x00;
-
- private final FrameType frameType;
-
- public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
- super(data);
- this.frameType = frameType;
- }
-
- public FrameType frameType() {
- return frameType;
- }
-
- public byte versionedFrameTypeByte() {
- final byte frameBits = frameType().getFrameBits();
- return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
- }
-
-
- public enum FrameType {
- /**
- * The payload is the initiator message for a Noise NK handshake. If established, the
- * session will be unauthenticated.
- */
- NK_HANDSHAKE((byte) 1),
- /**
- * The payload is the initiator message for a Noise IK handshake. If established, the
- * session will be authenticated.
- */
- IK_HANDSHAKE((byte) 2),
- /**
- * The payload is an encrypted noise packet.
- */
- DATA((byte) 3),
- /**
- * A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being
- * closed.
- */
- CLOSE((byte) 4);
-
- private final byte frameType;
-
- FrameType(byte frameType) {
- if (frameType != (0x0F & frameType)) {
- throw new IllegalStateException("Frame type must fit in 4 bits");
- }
- this.frameType = frameType;
- }
-
- public byte getFrameBits() {
- return frameType;
- }
-
- public boolean isHandshake() {
- return switch (this) {
- case IK_HANDSHAKE, NK_HANDSHAKE -> true;
- case DATA, CLOSE -> false;
- };
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java
deleted file mode 100644
index 4d94d0400..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
-
-import com.southernstorm.noise.protocol.Noise;
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.util.ReferenceCountUtil;
-import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
-
-/**
- * Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
- * have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
- */
-public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
-
- @Override
- public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
- if (msg instanceof ByteBuf byteBuf) {
- try {
- ctx.fireChannelRead(deserialize(byteBuf));
- } catch (Exception e) {
- ReferenceCountUtil.release(byteBuf);
- throw e;
- }
- } else {
- ctx.fireChannelRead(msg);
- }
- }
-
- @Override
- public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
- if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
- try {
- // Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
- // network, nothing should have to make another copy of this. If later another layer is added, it may be more
- // efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
- final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
- ctx.writeAndFlush(serialized, promise);
- } finally {
- ReferenceCountUtil.release(noiseDirectFrame);
- }
- } else {
- ctx.write(msg, promise);
- }
- }
-
- private ByteBuf serialize(
- final ChannelHandlerContext ctx,
- final NoiseDirectFrame noiseDirectFrame) {
- if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
- throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
- }
-
- // 1 version/frametype byte, 2 length bytes, content
- final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
-
- byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
- byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
- byteBuf.writeBytes(noiseDirectFrame.content());
- return byteBuf;
- }
-
- private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
- final byte versionAndFrameByte = byteBuf.readByte();
- final int version = (versionAndFrameByte & 0xF0) >> 4;
- if (version != NoiseDirectFrame.VERSION) {
- throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
- }
- final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
- final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
- case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
- case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
- case 3 -> NoiseDirectFrame.FrameType.DATA;
- case 4 -> NoiseDirectFrame.FrameType.CLOSE;
- default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
- };
-
- final int length = Short.toUnsignedInt(byteBuf.readShort());
- if (length != byteBuf.readableBytes()) {
- throw new IllegalArgumentException(
- "Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
- }
- return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java
deleted file mode 100644
index bf72bf927..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
-
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.util.ReferenceCountUtil;
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
-import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
-import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
-
-/**
- * Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and
- * forwards the handshake frame along as a {@link NoiseHandshakeInit} message
- */
-public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
-
- @Override
- public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
- if (msg instanceof NoiseDirectFrame frame) {
- try {
- if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) {
- throw new IOException("Could not determine remote address");
- }
- // We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler
- // needs into a NoiseHandshakeInit message and forward that along
- final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(),
- switch (frame.frameType()) {
- case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type");
- case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector");
- case IK_HANDSHAKE -> HandshakePattern.IK;
- case NK_HANDSHAKE -> HandshakePattern.NK;
- }, frame.content());
-
- // Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
- // should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
- // for network serialization. Note that we need to install the Data frame handler before firing the read,
- // because we may receive an outbound message from the noiseHandler
- ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
- ctx.fireChannelRead(handshakeMessage);
- } catch (Exception e) {
- ReferenceCountUtil.release(msg);
- throw e;
- }
- } else {
- ctx.fireChannelRead(msg);
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java
deleted file mode 100644
index 1548bdb98..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
-
-import io.micrometer.core.instrument.Metrics;
-import io.netty.buffer.ByteBufUtil;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.util.ReferenceCountUtil;
-import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
-
-
-/**
- * Watches for inbound close frames and closes the connection in response
- */
-public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter {
- private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "clientClose");
- @Override
- public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
- if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
- try {
- final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason
- .parseFrom(ByteBufUtil.getBytes(ndf.content()));
-
- Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment();
- } finally {
- ReferenceCountUtil.release(msg);
- ctx.close();
- }
- } else {
- ctx.fireChannelRead(msg);
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java
deleted file mode 100644
index b7b790e69..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java
+++ /dev/null
@@ -1,43 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
-
-import io.micrometer.core.instrument.Metrics;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufOutputStream;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelOutboundHandlerAdapter;
-import io.netty.channel.ChannelPromise;
-import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
-import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
-
-/**
- * Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
- * written, the channel is closed
- */
-class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
- private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "serverClose");
-
- @Override
- public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
- if (msg instanceof OutboundCloseErrorMessage err) {
- final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) {
- case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE;
- case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR;
- case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR;
- case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR;
- };
- Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "reason", code.name()).increment();
-
- final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder()
- .setCode(code)
- .setMessage(err.message())
- .build();
- final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
- proto.writeTo(new ByteBufOutputStream(byteBuf));
- ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf))
- .addListener(ChannelFutureListener.CLOSE);
- } else {
- ctx.write(msg, promise);
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java
deleted file mode 100644
index ee3680826..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java
+++ /dev/null
@@ -1,96 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.southernstorm.noise.protocol.Noise;
-import io.dropwizard.lifecycle.Managed;
-import io.netty.bootstrap.ServerBootstrap;
-import io.netty.channel.ChannelInitializer;
-import io.netty.channel.local.LocalAddress;
-import io.netty.channel.nio.NioEventLoopGroup;
-import io.netty.channel.socket.ServerSocketChannel;
-import io.netty.channel.socket.SocketChannel;
-import io.netty.channel.socket.nio.NioServerSocketChannel;
-import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
-import java.net.InetSocketAddress;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
-import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
-import org.whispersystems.textsecuregcm.grpc.net.FramingType;
-import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
-import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
-import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler;
-import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
-import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
-
-/**
- * A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
- * binary framing protocol) and passes it through to a local gRPC server.
- */
-public class NoiseDirectTunnelServer implements Managed {
-
- private final ServerBootstrap bootstrap;
- private ServerSocketChannel channel;
-
- private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
-
- public NoiseDirectTunnelServer(final int port,
- final NioEventLoopGroup eventLoopGroup,
- final GrpcClientConnectionManager grpcClientConnectionManager,
- final ClientPublicKeysManager clientPublicKeysManager,
- final ECKeyPair ecKeyPair,
- final LocalAddress authenticatedGrpcServerAddress,
- final LocalAddress anonymousGrpcServerAddress) {
-
- this.bootstrap = new ServerBootstrap()
- .group(eventLoopGroup)
- .channel(NioServerSocketChannel.class)
- .localAddress(port)
- .childHandler(new ChannelInitializer() {
- @Override
- protected void initChannel(SocketChannel socketChannel) {
- socketChannel.pipeline()
- .addLast(new ProxyProtocolDetectionHandler())
- .addLast(new HAProxyMessageHandler())
- // frame byte followed by a 2-byte length field
- .addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
- // Parses NoiseDirectFrames from wire bytes and vice versa
- .addLast(new NoiseDirectFrameCodec())
- // Terminate the connection if the client sends us a close frame
- .addLast(new NoiseDirectInboundCloseHandler())
- // Turn generic OutboundCloseErrorMessages into noise direct error frames
- .addLast(new NoiseDirectOutboundErrorHandler())
- // Forwards the first payload supplemented with handshake metadata, and then replaces itself with a
- // NoiseDirectDataFrameCodec to handle subsequent data frames
- .addLast(new NoiseDirectHandshakeSelector())
- // Performs the noise handshake and then replace itself with a NoiseHandler
- .addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
- // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
- // once the Noise handshake has completed
- .addLast(new EstablishLocalGrpcConnectionHandler(
- grpcClientConnectionManager,
- authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
- FramingType.NOISE_DIRECT))
- .addLast(new ErrorHandler());
- }
- });
- }
-
- @VisibleForTesting
- public InetSocketAddress getLocalAddress() {
- return channel.localAddress();
- }
-
- @Override
- public void start() throws InterruptedException {
- channel = (ServerSocketChannel) bootstrap.bind().await().channel();
- }
-
- @Override
- public void stop() throws InterruptedException {
- if (channel != null) {
- channel.close().await();
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java
deleted file mode 100644
index 67a24b99f..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java
+++ /dev/null
@@ -1,18 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
-
-enum ApplicationWebSocketCloseReason {
- NOISE_HANDSHAKE_ERROR(4001),
- NOISE_ENCRYPTION_ERROR(4002);
-
- private final int statusCode;
-
- ApplicationWebSocketCloseReason(final int statusCode) {
- this.statusCode = statusCode;
- }
-
- public int getStatusCode() {
- return statusCode;
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java
deleted file mode 100644
index eaff83bfa..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java
+++ /dev/null
@@ -1,146 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.southernstorm.noise.protocol.Noise;
-import io.dropwizard.lifecycle.Managed;
-import io.netty.bootstrap.ServerBootstrap;
-import io.netty.channel.ChannelInitializer;
-import io.netty.channel.local.LocalAddress;
-import io.netty.channel.nio.NioEventLoopGroup;
-import io.netty.channel.socket.ServerSocketChannel;
-import io.netty.channel.socket.SocketChannel;
-import io.netty.channel.socket.nio.NioServerSocketChannel;
-import io.netty.handler.codec.http.HttpObjectAggregator;
-import io.netty.handler.codec.http.HttpServerCodec;
-import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
-import io.netty.handler.ssl.ClientAuth;
-import io.netty.handler.ssl.OpenSsl;
-import io.netty.handler.ssl.SslContext;
-import io.netty.handler.ssl.SslContextBuilder;
-import io.netty.handler.ssl.SslProtocols;
-import io.netty.handler.ssl.SslProvider;
-import java.net.InetSocketAddress;
-import java.security.PrivateKey;
-import java.security.cert.X509Certificate;
-import java.util.concurrent.Executor;
-import javax.annotation.Nullable;
-import javax.net.ssl.SSLException;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.grpc.net.*;
-import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
-
-/**
- * A Noise-over-WebSocket tunnel server accepts traffic from the public internet (in the form of Noise packets framed by
- * binary WebSocket frames) and passes it through to a local gRPC server.
- */
-public class NoiseWebSocketTunnelServer implements Managed {
-
- private final ServerBootstrap bootstrap;
- private ServerSocketChannel channel;
-
- static final String AUTHENTICATED_SERVICE_PATH = "/authenticated";
- static final String ANONYMOUS_SERVICE_PATH = "/anonymous";
- static final String HEALTH_CHECK_PATH = "/health-check";
-
- private static final Logger log = LoggerFactory.getLogger(NoiseWebSocketTunnelServer.class);
-
- public NoiseWebSocketTunnelServer(final int websocketPort,
- @Nullable final X509Certificate[] tlsCertificateChain,
- @Nullable final PrivateKey tlsPrivateKey,
- final NioEventLoopGroup eventLoopGroup,
- final Executor delegatedTaskExecutor,
- final GrpcClientConnectionManager grpcClientConnectionManager,
- final ClientPublicKeysManager clientPublicKeysManager,
- final ECKeyPair ecKeyPair,
- final LocalAddress authenticatedGrpcServerAddress,
- final LocalAddress anonymousGrpcServerAddress,
- final String recognizedProxySecret) throws SSLException {
-
- @Nullable final SslContext sslContext;
-
- if (tlsCertificateChain != null && tlsPrivateKey != null) {
- final SslProvider sslProvider;
-
- if (OpenSsl.isAvailable()) {
- log.info("Native OpenSSL provider is available; will use native provider");
- sslProvider = SslProvider.OPENSSL;
- } else {
- log.info("No native SSL provider available; will use JDK provider");
- sslProvider = SslProvider.JDK;
- }
-
- sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain)
- .clientAuth(ClientAuth.NONE)
- // Some load balancers require TLS 1.2 for health checks
- .protocols(SslProtocols.TLS_v1_3, SslProtocols.TLS_v1_2)
- .sslProvider(sslProvider)
- .build();
- } else {
- log.warn("No TLS credentials provided; Noise-over-WebSocket tunnel will not use TLS. This configuration is not suitable for production environments.");
- sslContext = null;
- }
-
- this.bootstrap = new ServerBootstrap()
- .group(eventLoopGroup)
- .channel(NioServerSocketChannel.class)
- .localAddress(websocketPort)
- .childHandler(new ChannelInitializer() {
- @Override
- protected void initChannel(SocketChannel socketChannel) {
- socketChannel.pipeline()
- .addLast(new ProxyProtocolDetectionHandler())
- .addLast(new HAProxyMessageHandler());
-
- if (sslContext != null) {
- socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor));
- }
-
- socketChannel.pipeline()
- .addLast(new HttpServerCodec())
- .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
- // The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade
- // request and passed it down the pipeline
- .addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
- .addLast(new WebSocketServerProtocolHandler("/", true))
- // Metrics on inbound/outbound Close frames
- .addLast(new WebSocketCloseMetricHandler())
- // Turn generic OutboundCloseErrorMessages into websocket close frames
- .addLast(new WebSocketOutboundErrorHandler())
- .addLast(new RejectUnsupportedMessagesHandler())
- .addLast(new WebsocketPayloadCodec())
- // The WebSocket handshake complete listener will forward the first payload supplemented with
- // data from the websocket handshake completion event, and then remove itself from the pipeline
- .addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret))
- // The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a
- // NoiseHandler
- .addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
- // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
- // once the Noise handshake has completed
- .addLast(new EstablishLocalGrpcConnectionHandler(
- grpcClientConnectionManager,
- authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
- FramingType.WEBSOCKET))
- .addLast(new ErrorHandler());
- }
- });
- }
-
- @VisibleForTesting
- InetSocketAddress getLocalAddress() {
- return channel.localAddress();
- }
-
- @Override
- public void start() throws InterruptedException {
- channel = (ServerSocketChannel) bootstrap.bind().await().channel();
- }
-
- @Override
- public void stop() throws InterruptedException {
- if (channel != null) {
- channel.close().await();
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java
deleted file mode 100644
index d313951b3..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java
+++ /dev/null
@@ -1,35 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
-import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
-import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
-import io.netty.handler.codec.http.websocketx.WebSocketFrame;
-import io.netty.util.ReferenceCountUtil;
-
-/**
- * A "reject unsupported message" handler closes the channel if it receives messages it does not know how to process.
- */
-public class RejectUnsupportedMessagesHandler extends ChannelInboundHandlerAdapter {
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
- if (message instanceof final WebSocketFrame webSocketFrame) {
- if (webSocketFrame instanceof final TextWebSocketFrame textWebSocketFrame) {
- try {
- context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
- } finally {
- textWebSocketFrame.release();
- }
- } else {
- // Allow all other types of WebSocket frames
- context.fireChannelRead(webSocketFrame);
- }
- } else {
- // Discard anything that's not a WebSocket frame
- ReferenceCountUtil.release(message);
- context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketCloseMetricHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketCloseMetricHandler.java
deleted file mode 100644
index 52fa3de2d..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketCloseMetricHandler.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import io.micrometer.core.instrument.Metrics;
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
-import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
-
-public class WebSocketCloseMetricHandler extends ChannelDuplexHandler {
-
- private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "clientClose");
- private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "serverClose");
-
-
- @Override
- public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
- if (msg instanceof CloseWebSocketFrame closeFrame) {
- Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "closeCode", validatedCloseCode(closeFrame.statusCode())).increment();
- }
- ctx.fireChannelRead(msg);
- }
-
-
- @Override
- public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
- if (msg instanceof CloseWebSocketFrame closeFrame) {
- Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "closeCode", Integer.toString(closeFrame.statusCode())).increment();
- }
- ctx.write(msg, promise);
- }
-
- private static String validatedCloseCode(int closeCode) {
-
- if (closeCode >= 1000 && closeCode <= 1015) {
- // RFC-6455 pre-defined status codes
- return Integer.toString(closeCode);
- } else if (closeCode >= 4000 && closeCode <= 4100) {
- // Application status codes
- return Integer.toString(closeCode);
- } else {
- return "unknown";
- }
- }
-
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java
deleted file mode 100644
index e70cdec8e..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java
+++ /dev/null
@@ -1,81 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.handler.codec.http.DefaultFullHttpResponse;
-import io.netty.handler.codec.http.FullHttpRequest;
-import io.netty.handler.codec.http.HttpHeaderNames;
-import io.netty.handler.codec.http.HttpHeaderValues;
-import io.netty.handler.codec.http.HttpMethod;
-import io.netty.handler.codec.http.HttpResponseStatus;
-import io.netty.util.ReferenceCountUtil;
-
-/**
- * A WebSocket opening handshake handler serves as the "front door" for the WebSocket/Noise tunnel and gracefully
- * rejects requests for anything other than a WebSocket connection to a known endpoint.
- */
-class WebSocketOpeningHandshakeHandler extends ChannelInboundHandlerAdapter {
-
- private final String authenticatedPath;
- private final String anonymousPath;
- private final String healthCheckPath;
-
- WebSocketOpeningHandshakeHandler(final String authenticatedPath,
- final String anonymousPath,
- final String healthCheckPath) {
-
- this.authenticatedPath = authenticatedPath;
- this.anonymousPath = anonymousPath;
- this.healthCheckPath = healthCheckPath;
- }
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
- if (message instanceof FullHttpRequest request) {
- boolean shouldReleaseRequest = true;
-
- try {
- if (request.decoderResult().isSuccess()) {
- if (HttpMethod.GET.equals(request.method())) {
- if (authenticatedPath.equals(request.uri()) || anonymousPath.equals(request.uri())) {
- if (request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) {
- // Pass the request along to the websocket handshake handler and remove ourselves from the pipeline
- shouldReleaseRequest = false;
-
- context.fireChannelRead(request);
- context.pipeline().remove(this);
- } else {
- closeConnectionWithStatus(context, request, HttpResponseStatus.UPGRADE_REQUIRED);
- }
- } else if (healthCheckPath.equals(request.uri())) {
- closeConnectionWithStatus(context, request, HttpResponseStatus.NO_CONTENT);
- } else {
- closeConnectionWithStatus(context, request, HttpResponseStatus.NOT_FOUND);
- }
- } else {
- closeConnectionWithStatus(context, request, HttpResponseStatus.METHOD_NOT_ALLOWED);
- }
- } else {
- closeConnectionWithStatus(context, request, HttpResponseStatus.BAD_REQUEST);
- }
- } finally {
- if (shouldReleaseRequest) {
- request.release();
- }
- }
- } else {
- // Anything except HTTP requests should have been filtered out of the pipeline by now; treat this as an error
- ReferenceCountUtil.release(message);
- throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
- }
- }
-
- private static void closeConnectionWithStatus(final ChannelHandlerContext context,
- final FullHttpRequest request,
- final HttpResponseStatus status) {
-
- context.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), status))
- .addListener(ChannelFutureListener.CLOSE);
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java
deleted file mode 100644
index 421d575bb..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java
+++ /dev/null
@@ -1,60 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
-import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
-import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
-import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
-
-/**
- * Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
- */
-class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
- private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketOutboundErrorHandler.class, "serverClose");
-
- private boolean websocketHandshakeComplete = false;
-
- private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
-
- @Override
- public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
- if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
- setWebsocketHandshakeComplete();
- }
-
- context.fireUserEventTriggered(event);
- }
-
- protected void setWebsocketHandshakeComplete() {
- this.websocketHandshakeComplete = true;
- }
-
- @Override
- public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
- if (msg instanceof OutboundCloseErrorMessage err) {
- if (websocketHandshakeComplete) {
- final int status = switch (err.code()) {
- case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
- case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
- case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
- case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
- };
- ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
- .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
- } else {
- log.debug("Error {} occurred before websocket handshake complete", err);
- // We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
- // way; just close the connection instead.
- ctx.close();
- }
- } else {
- ctx.write(msg, promise);
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java
deleted file mode 100644
index 7093934bd..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java
+++ /dev/null
@@ -1,152 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.net.InetAddresses;
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
-import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
-import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
-import io.netty.util.ReferenceCountUtil;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.nio.charset.StandardCharsets;
-import java.security.MessageDigest;
-import java.util.Optional;
-import javax.annotation.Nullable;
-import org.apache.commons.lang3.StringUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
-import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
-
-/**
- * A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
- * Noise handshake handler for the requested path.
- */
-class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
-
- private final byte[] recognizedProxySecret;
-
- private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
-
- @VisibleForTesting
- static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy";
-
- @VisibleForTesting
- static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
-
- private InetAddress remoteAddress = null;
- private HandshakePattern handshakePattern = null;
-
- WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) {
-
- // The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
- // encoded value). We convert it into a byte array here for easier constant-time comparisons via
- // MessageDigest.equals() later.
- this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8);
- }
-
- @Override
- public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
- if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
- final Optional maybePreferredRemoteAddress =
- getPreferredRemoteAddress(context, handshakeCompleteEvent);
-
- if (maybePreferredRemoteAddress.isEmpty()) {
- context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
- "Could not determine remote address"))
- .addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
-
- return;
- }
-
- remoteAddress = maybePreferredRemoteAddress.get();
- handshakePattern = switch (handshakeCompleteEvent.requestUri()) {
- case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK;
- case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK;
- // The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
- // internal error if something slipped through.
- default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
- };
- }
-
- context.fireUserEventTriggered(event);
- }
-
- @Override
- public void channelRead(final ChannelHandlerContext context, final Object msg) {
- try {
- if (!(msg instanceof ByteBuf frame)) {
- throw new IllegalStateException("Unexpected msg type: " + msg.getClass());
- }
-
- if (handshakePattern == null || remoteAddress == null) {
- throw new IllegalStateException("Received payload before websocket handshake complete");
- }
-
- final NoiseHandshakeInit handshakeMessage =
- new NoiseHandshakeInit(remoteAddress, handshakePattern, frame);
-
- context.pipeline().remove(WebsocketHandshakeCompleteHandler.class);
- context.fireChannelRead(handshakeMessage);
- } catch (Exception e) {
- ReferenceCountUtil.release(msg);
- throw e;
- }
- }
-
- private Optional getPreferredRemoteAddress(final ChannelHandlerContext context,
- final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
-
- final byte[] recognizedProxySecretFromHeader =
- handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
- .getBytes(StandardCharsets.UTF_8);
-
- final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
-
- if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) {
- final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER);
-
- return getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
- try {
- return InetAddresses.forString(mostRecentProxy);
- } catch (final IllegalArgumentException e) {
- log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e);
- return null;
- }
- });
- } else {
- // Either we don't trust the forwarded-for header or it's not present
- if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
- return Optional.of(inetSocketAddress.getAddress());
- } else {
- log.warn("Channel's remote address was not an InetSocketAddress");
- return Optional.empty();
- }
- }
- }
-
- /**
- * Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header.
- *
- * @param forwardedFor the value of an X-Forwarded-For header
- * @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or
- * {@code forwardedFor} was null
- * @see X-Forwarded-For - HTTP |
- * MDN
- */
- @VisibleForTesting
- static Optional getMostRecentProxy(@Nullable final String forwardedFor) {
- return Optional.ofNullable(forwardedFor)
- .map(ff -> {
- final int idx = forwardedFor.lastIndexOf(',') + 1;
- return idx < forwardedFor.length()
- ? forwardedFor.substring(idx).trim()
- : null;
- })
- .filter(StringUtils::isNotBlank);
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java
deleted file mode 100644
index 4ef5bd151..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-package org.whispersystems.textsecuregcm.grpc.net.websocket;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
-
-/**
- * Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
- * {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
- * packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
- */
-public class WebsocketPayloadCodec extends ChannelDuplexHandler {
-
- @Override
- public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
- if (msg instanceof BinaryWebSocketFrame frame) {
- ctx.fireChannelRead(frame.content());
- } else {
- ctx.fireChannelRead(msg);
- }
- }
-
- @Override
- public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
- if (msg instanceof ByteBuf bb) {
- ctx.write(new BinaryWebSocketFrame(bb), promise);
- } else {
- ctx.write(msg, promise);
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java
index e5386896e..d283c0c3c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java
@@ -15,11 +15,7 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
-import java.security.InvalidKeyException;
-import java.security.NoSuchAlgorithmException;
-import java.security.cert.CertificateException;
import java.time.Clock;
-import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
@@ -43,7 +39,6 @@ import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
-import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
@@ -267,9 +262,8 @@ record CommandDependencies(
() -> dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion());
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
storageServiceExecutor, retryExecutor, configuration.getSecureStorageServiceConfiguration());
- GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager();
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient,
- grpcClientConnectionManager, disconnectionRequestListenerExecutor, retryExecutor);
+ disconnectionRequestListenerExecutor, retryExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster,
messageDeliveryScheduler, messageDeletionExecutor, retryExecutor, Clock.systemUTC(), experimentEnrollmentManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, retryExecutor, asyncCdnS3Client,
diff --git a/service/src/main/proto/NoiseDirect.proto b/service/src/main/proto/NoiseDirect.proto
deleted file mode 100644
index a5c1ecf8c..000000000
--- a/service/src/main/proto/NoiseDirect.proto
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-syntax = "proto3";
-
-option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect";
-option java_outer_classname = "NoiseDirectProtos";
-
-message CloseReason {
- enum Code {
- UNSPECIFIED = 0;
- // Indicates non-error termination
- // Examples:
- // - The client is finished with the connection
- OK = 1;
-
- // There was an issue with the handshake. If sent after a handshake response,
- // the response includes more information about the nature of the error
- // Examples:
- // - The client did not provide a handshake message
- // - The client had incorrect authentication credentials. The handshake
- // payload includes additional details
- HANDSHAKE_ERROR = 2;
-
- // There was an encryption/decryption issue after the handshake
- // Examples:
- // - The client incorrectly encrypted a noise message and it had a bad
- // AEAD tag
- ENCRYPTION_ERROR = 3;
-
- // The server is temporarily unavailable, going away, or requires a
- // connection reset
- // Examples:
- // - The server is shutting down
- // - The client’s authentication credentials have been rotated
- UNAVAILABLE = 4;
-
- // There was an an internal error
- // Examples:
- // - The server experienced a temporary database outage that prevented it
- // from checking the client's credentials
- INTERNAL_ERROR = 5;
- }
-
- Code code = 1;
-
- // If present, includes details about the error. Implementations should never
- // parse or otherwise implement conditional logic based on the contents of the
- // error message string, it is for logging and debugging purposes only.
- string message = 2;
-}
diff --git a/service/src/main/proto/NoiseTunnel.proto b/service/src/main/proto/NoiseTunnel.proto
deleted file mode 100644
index 45695e20a..000000000
--- a/service/src/main/proto/NoiseTunnel.proto
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-syntax = "proto3";
-
-option java_package = "org.whispersystems.textsecuregcm.grpc.net";
-option java_outer_classname = "NoiseTunnelProtos";
-
-message HandshakeInit {
- string user_agent = 1;
-
- // An Accept-Language as described in
- // https://httpwg.org/specs/rfc9110.html#field.accept-language
- string accept_language = 2;
-
- // A UUID serialized as 16 bytes (big end first). Must be unset (empty) for an
- // unauthenticated handshake
- bytes aci = 3;
-
- // The deviceId, 0 < deviceId < 128. Must be unset for an unauthenticated
- // handshake
- uint32 device_id = 4;
-
- // The first bytes of the application request byte stream, may contain less
- // than a full request
- bytes fast_open_request = 5;
-}
-
-message HandshakeResponse {
- enum Code {
- UNSPECIFIED = 0;
-
- // The noise session may be used to send application layer requests
- OK = 1;
-
- // The provided client static key did not match the registered public key
- // for the provided aci/deviceId.
- WRONG_PUBLIC_KEY = 2;
-
- // The client version is to old, it should be upgraded before retrying
- DEPRECATED = 3;
- }
-
- // The handshake outcome
- Code code = 1;
-
- // Additional information about an error status, for debugging only
- string error_details = 2;
-
- // An optional response to a fast_open_request provided in the HandshakeInit.
- // Note that a response may not be present even if a fast_open_request was
- // present. If so, the response will be returned in a later message.
- bytes fast_open_response = 3;
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java
index cfbf2cd94..7c3f78f7a 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java
@@ -20,8 +20,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
-import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
-import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -30,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class DisconnectionRequestManagerTest {
- private GrpcClientConnectionManager grpcClientConnectionManager;
private DisconnectionRequestManager disconnectionRequestManager;
@RegisterExtension
@@ -38,10 +35,8 @@ class DisconnectionRequestManagerTest {
@BeforeEach
void setUp() {
- grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
disconnectionRequestManager = new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(),
- grpcClientConnectionManager,
Runnable::run,
mock(ScheduledExecutorService.class));
@@ -103,16 +98,8 @@ class DisconnectionRequestManagerTest {
verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest();
verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest();
- verify(grpcClientConnectionManager, timeout(1_000))
- .closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId));
-
- verify(grpcClientConnectionManager, timeout(1_000))
- .closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId));
disconnectionRequestManager.requestDisconnection(otherAccountIdentifier, List.of(otherDeviceId));
-
- verify(grpcClientConnectionManager, timeout(1_000))
- .closeConnection(new AuthenticatedDevice(otherAccountIdentifier, otherDeviceId));
}
@Test
@@ -141,11 +128,5 @@ class DisconnectionRequestManagerTest {
verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest();
verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest();
-
- verify(grpcClientConnectionManager, timeout(1_000))
- .closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId));
-
- verify(grpcClientConnectionManager, timeout(1_000))
- .closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId));
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java
deleted file mode 100644
index 95bacd076..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java
+++ /dev/null
@@ -1,77 +0,0 @@
-package org.whispersystems.textsecuregcm.auth.grpc;
-
-import static org.mockito.Mockito.mock;
-
-import io.grpc.ManagedChannel;
-import io.grpc.Server;
-import io.grpc.netty.NettyChannelBuilder;
-import io.grpc.netty.NettyServerBuilder;
-import io.netty.channel.DefaultEventLoopGroup;
-import io.netty.channel.local.LocalAddress;
-import io.netty.channel.local.LocalChannel;
-import io.netty.channel.local.LocalServerChannel;
-import java.io.IOException;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.BeforeEach;
-import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
-import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
-import org.signal.chat.rpc.RequestAttributesGrpc;
-import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
-import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
-
-abstract class AbstractAuthenticationInterceptorTest {
-
- private static DefaultEventLoopGroup eventLoopGroup;
-
- private GrpcClientConnectionManager grpcClientConnectionManager;
-
- private Server server;
- private ManagedChannel managedChannel;
-
- @BeforeAll
- static void setUpBeforeAll() {
- eventLoopGroup = new DefaultEventLoopGroup();
- }
-
- @BeforeEach
- void setUp() throws IOException {
- final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server");
-
- grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
-
- // `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
- // sure that we're using local channels and addresses
- server = NettyServerBuilder.forAddress(serverAddress)
- .channelType(LocalServerChannel.class)
- .bossEventLoopGroup(eventLoopGroup)
- .workerEventLoopGroup(eventLoopGroup)
- .intercept(getInterceptor())
- .addService(new RequestAttributesServiceImpl())
- .build()
- .start();
-
- managedChannel = NettyChannelBuilder.forAddress(serverAddress)
- .channelType(LocalChannel.class)
- .eventLoopGroup(eventLoopGroup)
- .usePlaintext()
- .build();
- }
-
- @AfterEach
- void tearDown() {
- managedChannel.shutdown();
- server.shutdown();
- }
-
- protected abstract AbstractAuthenticationInterceptor getInterceptor();
-
- protected GrpcClientConnectionManager getClientConnectionManager() {
- return grpcClientConnectionManager;
- }
-
- protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() {
- return RequestAttributesGrpc.newBlockingStub(managedChannel)
- .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicAuthCallCredentials.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicAuthCallCredentials.java
new file mode 100644
index 000000000..98e2c676a
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicAuthCallCredentials.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2025 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+package org.whispersystems.textsecuregcm.auth.grpc;
+
+import io.grpc.CallCredentials;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import java.util.concurrent.Executor;
+import org.whispersystems.textsecuregcm.util.HeaderUtils;
+
+public class BasicAuthCallCredentials extends CallCredentials {
+
+ private final String username;
+ private final String password;
+
+ public BasicAuthCallCredentials(String username, String password) {
+ this.username = username;
+ this.password = password;
+ }
+
+ @Override
+ public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor,
+ final MetadataApplier applier) {
+ try {
+ Metadata headers = new Metadata();
+ headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER),
+ HeaderUtils.basicAuthHeader(username, password));
+ applier.apply(headers);
+ } catch (Exception e) {
+ applier.fail(Status.UNAUTHENTICATED.withCause(e));
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java
index a0d8c6688..189af0bd0 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java
@@ -1,44 +1,64 @@
package org.whispersystems.textsecuregcm.auth.grpc;
+import io.grpc.ManagedChannel;
+import io.grpc.Server;
import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
-import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
-import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
-import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
-import org.whispersystems.textsecuregcm.storage.Device;
+import org.signal.chat.rpc.EchoRequest;
+import org.signal.chat.rpc.EchoServiceGrpc;
+import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
-import java.util.Optional;
-import java.util.UUID;
+import java.util.concurrent.TimeUnit;
-import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.when;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
+class ProhibitAuthenticationInterceptorTest {
+ private Server server;
+ private ManagedChannel channel;
- @Override
- protected AbstractAuthenticationInterceptor getInterceptor() {
- return new ProhibitAuthenticationInterceptor(getClientConnectionManager());
+ @BeforeEach
+ void setUp() throws Exception {
+ server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest")
+ .directExecutor()
+ .intercept(new ProhibitAuthenticationInterceptor())
+ .addService(new EchoServiceImpl())
+ .build()
+ .start();
+
+ channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest")
+ .directExecutor()
+ .build();
+ }
+
+ @AfterEach
+ void tearDown() throws Exception {
+ channel.shutdownNow();
+ server.shutdownNow();
+ channel.awaitTermination(5, TimeUnit.SECONDS);
+ server.awaitTermination(5, TimeUnit.SECONDS);
}
@Test
- void interceptCall() throws ChannelNotFoundException {
- final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
+ void hasAuth() {
+ final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc
+ .newBlockingStub(channel)
+ .withCallCredentials(new BasicAuthCallCredentials("test", "password"));
- when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
+ final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
+ () -> client.echo(EchoRequest.getDefaultInstance()));
+ assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
+ }
- final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
- assertTrue(response.getAccountIdentifier().isEmpty());
- assertEquals(0, response.getDeviceId());
-
- final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
- when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
-
- GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
-
- when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
-
- GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
+ @Test
+ void noAuth() {
+ final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
+ assertDoesNotThrow(() -> client.echo(EchoRequest.getDefaultInstance()));
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java
index 442b38cc8..90bb86f79 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java
@@ -1,44 +1,102 @@
package org.whispersystems.textsecuregcm.auth.grpc;
+import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
+import io.dropwizard.auth.basic.BasicCredentials;
+import io.grpc.ManagedChannel;
+import io.grpc.Server;
import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import java.time.Instant;
import java.util.Optional;
import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+import org.junit.Assert;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
-import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
-import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
-import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
-import org.whispersystems.textsecuregcm.storage.Device;
+import org.signal.chat.rpc.GetRequestAttributesRequest;
+import org.signal.chat.rpc.RequestAttributesGrpc;
+import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
+import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
-class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
+class RequireAuthenticationInterceptorTest {
+ private Server server;
+ private ManagedChannel channel;
+ private AccountAuthenticator authenticator;
- @Override
- protected AbstractAuthenticationInterceptor getInterceptor() {
- return new RequireAuthenticationInterceptor(getClientConnectionManager());
+ @BeforeEach
+ void setUp() throws Exception {
+ authenticator = mock(AccountAuthenticator.class);
+ server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest")
+ .directExecutor()
+ .intercept(new RequireAuthenticationInterceptor(authenticator))
+ .addService(new RequestAttributesServiceImpl())
+ .build()
+ .start();
+
+ channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest")
+ .directExecutor()
+ .build();
+ }
+
+ @AfterEach
+ void tearDown() throws Exception {
+ channel.shutdownNow();
+ server.shutdownNow();
+ channel.awaitTermination(5, TimeUnit.SECONDS);
+ server.awaitTermination(5, TimeUnit.SECONDS);
}
@Test
- void interceptCall() throws ChannelNotFoundException {
- final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
+ void hasAuth() {
+ final UUID aci = UUID.randomUUID();
+ final byte deviceId = 2;
+ when(authenticator.authenticate(eq(new BasicCredentials("test", "password"))))
+ .thenReturn(Optional.of(
+ new org.whispersystems.textsecuregcm.auth.AuthenticatedDevice(aci, deviceId, Instant.now())));
- when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
+ final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
+ .newBlockingStub(channel)
+ .withCallCredentials(new BasicAuthCallCredentials("test", "password"));
- GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
+ final GetAuthenticatedDeviceResponse authenticatedDevice = client.getAuthenticatedDevice(
+ GetAuthenticatedDeviceRequest.getDefaultInstance());
+ assertEquals(authenticatedDevice.getDeviceId(), deviceId);
+ assertEquals(UUIDUtil.fromByteString(authenticatedDevice.getAccountIdentifier()), aci);
+ }
- final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
- when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
+ @Test
+ void badCredentials() {
+ when(authenticator.authenticate(any())).thenReturn(Optional.empty());
- final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
- assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
- assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
+ final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
+ .newBlockingStub(channel)
+ .withCallCredentials(new BasicAuthCallCredentials("test", "password"));
- when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
+ final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
+ () -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance()));
+ Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
+ }
- GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
+ @Test
+ void missingCredentials() {
+ when(authenticator.authenticate(any())).thenReturn(Optional.empty());
+
+ final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc.newBlockingStub(channel);
+
+ final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
+ () -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance()));
+ Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java
deleted file mode 100644
index 35db29a77..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-package org.whispersystems.textsecuregcm.grpc;
-
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-import io.grpc.Metadata;
-import io.grpc.ServerCall;
-import io.grpc.ServerCallHandler;
-import io.grpc.Status;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
-
-class ChannelShutdownInterceptorTest {
-
- private GrpcClientConnectionManager grpcClientConnectionManager;
- private ChannelShutdownInterceptor channelShutdownInterceptor;
-
- private ServerCallHandler nextCallHandler;
-
- private static final Metadata HEADERS = new Metadata();
-
- @BeforeEach
- void setUp() {
- grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
- channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
-
- //noinspection unchecked
- nextCallHandler = mock(ServerCallHandler.class);
-
- //noinspection unchecked
- when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
- }
-
- @Test
- void interceptCallComplete() {
- @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class);
-
- when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
-
- final ServerCall.Listener serverCallListener =
- channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
-
- serverCallListener.onComplete();
-
- verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
- verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
- verify(serverCall, never()).close(any(), any());
- }
-
- @Test
- void interceptCallCancelled() {
- @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class);
-
- when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
-
- final ServerCall.Listener serverCallListener =
- channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
-
- serverCallListener.onCancel();
-
- verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
- verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
- verify(serverCall, never()).close(any(), any());
- }
-
- @Test
- void interceptCallChannelClosing() {
- @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class);
-
- when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
-
- channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
-
- verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
- verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
- verify(serverCall).close(eq(Status.UNAVAILABLE), any());
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptorTest.java
new file mode 100644
index 000000000..e30dd9549
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptorTest.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright 2025 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+package org.whispersystems.textsecuregcm.grpc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import io.grpc.ManagedChannel;
+import io.grpc.Metadata;
+import io.grpc.Server;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.netty.NettyChannelBuilder;
+import io.grpc.netty.NettyServerBuilder;
+import io.grpc.stub.MetadataUtils;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.signal.chat.rpc.GetRequestAttributesRequest;
+import org.signal.chat.rpc.GetRequestAttributesResponse;
+import org.signal.chat.rpc.RequestAttributesGrpc;
+
+public class RequestAttributesInterceptorTest {
+
+ private static String USER_AGENT = "Signal-Android/4.53.7 (Android 8.1; libsignal)";
+ private Server server;
+ private AtomicBoolean removeUserAgent;
+
+ @BeforeEach
+ void setUp() throws Exception {
+ removeUserAgent = new AtomicBoolean(false);
+
+ server = NettyServerBuilder.forPort(0)
+ .directExecutor()
+ .intercept(new RequestAttributesInterceptor())
+ // the grpc client always inserts a user-agent if we don't set one, so to test missing UAs we remove the header
+ // on the server-side
+ .intercept(new ServerInterceptor() {
+ @Override
+ public ServerCall.Listener interceptCall(final ServerCall call,
+ final Metadata headers, final ServerCallHandler next) {
+ if (removeUserAgent.get()) {
+ headers.removeAll(Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER));
+ }
+ return next.startCall(call, headers);
+ }
+ })
+ .addService(new RequestAttributesServiceImpl())
+ .build()
+ .start();
+ }
+
+ @AfterEach
+ void tearDown() throws Exception {
+ server.shutdownNow();
+ server.awaitTermination(1, TimeUnit.SECONDS);
+ }
+
+ private static List handleInvalidAcceptLanguage() {
+ return List.of(
+ Arguments.argumentSet("Null Accept-Language header", Optional.empty()),
+ Arguments.argumentSet("Empty Accept-Language header", Optional.of("")),
+ Arguments.argumentSet("Invalid Accept-Language header", Optional.of("This is not a valid language preference list")));
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void handleInvalidAcceptLanguage(Optional acceptLanguageHeader) throws Exception {
+ final Metadata metadata = new Metadata();
+ acceptLanguageHeader.ifPresent(h -> metadata
+ .put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), h));
+ final GetRequestAttributesResponse response = getRequestAttributes(metadata);
+ assertEquals(response.getAcceptableLanguagesCount(), 0);
+ }
+
+ @Test
+ void handleMissingUserAgent() throws InterruptedException {
+ removeUserAgent.set(true);
+ final GetRequestAttributesResponse response = getRequestAttributes(new Metadata());
+ assertEquals("", response.getUserAgent());
+ }
+
+ @Test
+ void allAttributes() throws InterruptedException {
+ final Metadata metadata = new Metadata();
+ metadata.put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), "ja,en;q=0.4");
+ metadata.put(Metadata.Key.of("x-forwarded-for", Metadata.ASCII_STRING_MARSHALLER), "127.0.0.3");
+ final GetRequestAttributesResponse response = getRequestAttributes(metadata);
+
+ assertTrue(response.getUserAgent().contains(USER_AGENT));
+ assertEquals("127.0.0.3", response.getRemoteAddress());
+ assertEquals(2, response.getAcceptableLanguagesCount());
+ assertEquals("ja", response.getAcceptableLanguages(0));
+ assertEquals("en;q=0.4", response.getAcceptableLanguages(1));
+ }
+
+ @Test
+ void useSocketAddrIfHeaderMissing() throws InterruptedException {
+ final GetRequestAttributesResponse response = getRequestAttributes(new Metadata());
+ assertEquals("127.0.0.1", response.getRemoteAddress());
+ }
+
+ private GetRequestAttributesResponse getRequestAttributes(Metadata metadata)
+ throws InterruptedException {
+ final ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", server.getPort())
+ .directExecutor()
+ .usePlaintext()
+ .userAgent(USER_AGENT)
+ .build();
+ try {
+ final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
+ .newBlockingStub(channel)
+ .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
+ return client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance());
+ } finally {
+ channel.shutdownNow();
+ channel.awaitTermination(1, TimeUnit.SECONDS);
+ }
+ }
+
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java
deleted file mode 100644
index 9cee9e7f1..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java
+++ /dev/null
@@ -1,21 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import io.netty.util.ResourceLeakDetector;
-import org.junit.jupiter.api.AfterAll;
-import org.junit.jupiter.api.BeforeAll;
-
-public abstract class AbstractLeakDetectionTest {
-
- private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
-
- @BeforeAll
- static void setLeakDetectionLevel() {
- originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel();
- ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
- }
-
- @AfterAll
- static void restoreLeakDetectionLevel() {
- ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel);
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java
deleted file mode 100644
index 78d328cd8..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java
+++ /dev/null
@@ -1,325 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertInstanceOf;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.mockito.Mockito.mock;
-
-import com.southernstorm.noise.protocol.CipherStatePair;
-import com.southernstorm.noise.protocol.Noise;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufUtil;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.channel.embedded.EmbeddedChannel;
-import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
-import io.netty.util.ReferenceCountUtil;
-import java.net.InetAddress;
-import java.net.UnknownHostException;
-import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.concurrent.ThreadLocalRandom;
-import javax.annotation.Nullable;
-import javax.crypto.AEADBadTagException;
-import javax.crypto.BadPaddingException;
-import javax.crypto.ShortBufferException;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.ValueSource;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
-import org.whispersystems.textsecuregcm.util.TestRandomUtil;
-
-abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
-
- protected ECKeyPair serverKeyPair;
- protected ClientPublicKeysManager clientPublicKeysManager;
-
- private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
-
- private EmbeddedChannel embeddedChannel;
-
- static final String USER_AGENT = "Test/User-Agent";
- static final String ACCEPT_LANGUAGE = "test-lang";
- static final InetAddress REMOTE_ADDRESS;
- static {
- try {
- REMOTE_ADDRESS = InetAddress.getByAddress(new byte[]{0,1,2,3});
- } catch (UnknownHostException e) {
- throw new RuntimeException(e);
- }
- }
-
- private static class PongHandler extends ChannelInboundHandlerAdapter {
-
- @Override
- public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
- try {
- if (msg instanceof ByteBuf bb) {
- if (new String(ByteBufUtil.getBytes(bb)).equals("ping")) {
- ctx.writeAndFlush(Unpooled.wrappedBuffer("pong".getBytes()))
- .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
- } else {
- throw new IllegalArgumentException("Unexpected message: " + new String(ByteBufUtil.getBytes(bb)));
- }
- } else {
- throw new IllegalArgumentException("Unexpected message type: " + msg);
- }
- } finally {
- ReferenceCountUtil.release(msg);
- }
- }
- }
-
- private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
-
- @Nullable
- private NoiseIdentityDeterminedEvent handshakeCompleteEvent = null;
-
- @Override
- public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
- if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
- handshakeCompleteEvent = noiseIdentityDeterminedEvent;
- context.pipeline().addAfter(context.name(), null, new PongHandler());
- context.pipeline().remove(NoiseHandshakeCompleteHandler.class);
- } else {
- context.fireUserEventTriggered(event);
- }
- }
-
- @Nullable
- public NoiseIdentityDeterminedEvent getHandshakeCompleteEvent() {
- return handshakeCompleteEvent;
- }
- }
-
- @BeforeEach
- void setUp() {
- serverKeyPair = ECKeyPair.generate();
- noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
- clientPublicKeysManager = mock(ClientPublicKeysManager.class);
- embeddedChannel = new EmbeddedChannel(
- new NoiseHandshakeHandler(clientPublicKeysManager, serverKeyPair),
- noiseHandshakeCompleteHandler);
- }
-
- @AfterEach
- void tearDown() {
- embeddedChannel.close();
- }
-
- protected EmbeddedChannel getEmbeddedChannel() {
- return embeddedChannel;
- }
-
- @Nullable
- protected NoiseIdentityDeterminedEvent getNoiseHandshakeCompleteEvent() {
- return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
- }
-
- protected abstract CipherStatePair doHandshake() throws Throwable;
-
- /**
- * Read a message from the embedded channel and deserialize it with the provided client cipher state. If there are no
- * waiting messages in the channel, return null.
- */
- byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
- final ByteBuf responseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
- if (responseFrame == null) {
- return null;
- }
- final byte[] plaintext = new byte[responseFrame.readableBytes() - 16];
- final int read = clientCipherPair.getReceiver().decryptWithAd(null,
- ByteBufUtil.getBytes(responseFrame), 0,
- plaintext, 0,
- responseFrame.readableBytes());
- assertEquals(read, plaintext.length);
- return plaintext;
- }
-
-
- @Test
- void handleInvalidInitialMessage() throws InterruptedException {
- final byte[] contentBytes = new byte[17];
- ThreadLocalRandom.current().nextBytes(contentBytes);
-
- final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
-
- final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new NoiseHandshakeInit(REMOTE_ADDRESS, HandshakePattern.IK, content)).await();
-
- assertFalse(writeFuture.isSuccess());
- assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
- assertEquals(0, content.refCnt());
- assertNull(getNoiseHandshakeCompleteEvent());
- }
-
- @Test
- void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
- final ByteBuf[] frames = new ByteBuf[7];
-
- for (int i = 0; i < frames.length; i++) {
- final byte[] contentBytes = new byte[17];
- ThreadLocalRandom.current().nextBytes(contentBytes);
-
- frames[i] = Unpooled.wrappedBuffer(contentBytes);
-
- embeddedChannel.writeOneInbound(frames[i]).await();
- }
-
- for (final ByteBuf frame : frames) {
- assertEquals(0, frame.refCnt());
- }
-
- assertNull(getNoiseHandshakeCompleteEvent());
- }
-
- @Test
- void handleNonByteBufBinaryFrame() throws Throwable {
- final byte[] contentBytes = new byte[17];
- ThreadLocalRandom.current().nextBytes(contentBytes);
-
- final BinaryWebSocketFrame message = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
-
- final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
-
- assertFalse(writeFuture.isSuccess());
- assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
- assertEquals(0, message.refCnt());
- assertNull(getNoiseHandshakeCompleteEvent());
-
- assertTrue(embeddedChannel.inboundMessages().isEmpty());
- }
-
- @Test
- void channelRead() throws Throwable {
- final CipherStatePair clientCipherStatePair = doHandshake();
- final byte[] plaintext = "ping".getBytes(StandardCharsets.UTF_8);
- final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
- clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
-
- final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(ciphertext);
- assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
- assertEquals(0, ciphertextFrame.refCnt());
-
- final byte[] response = readNextPlaintext(clientCipherStatePair);
- assertArrayEquals("pong".getBytes(StandardCharsets.UTF_8), response);
- }
-
- @Test
- void channelReadBadCiphertext() throws Throwable {
- doHandshake();
- final byte[] bogusCiphertext = new byte[32];
- io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
-
- final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(bogusCiphertext);
- final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
-
- assertEquals(0, ciphertextFrame.refCnt());
- assertFalse(readCiphertextFuture.isSuccess());
- assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
- assertTrue(embeddedChannel.inboundMessages().isEmpty());
- }
-
- @Test
- void channelReadUnexpectedMessageType() throws Throwable {
- doHandshake();
- final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
-
- assertFalse(readFuture.isSuccess());
- assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
- assertTrue(embeddedChannel.inboundMessages().isEmpty());
- }
-
- @Test
- void write() throws Throwable {
- final CipherStatePair clientCipherStatePair = doHandshake();
- final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
- final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
-
- final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
- assertTrue(writePlaintextFuture.await().isSuccess());
- assertEquals(0, plaintextBuffer.refCnt());
-
- final ByteBuf ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
- assertNotNull(ciphertextFrame);
- assertTrue(embeddedChannel.outboundMessages().isEmpty());
-
- final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
- ciphertextFrame.release();
-
- final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
- clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
-
- assertArrayEquals(plaintext, decryptedPlaintext);
- }
-
- @Test
- void writeUnexpectedMessageType() throws Throwable {
- doHandshake();
- final Object unexpectedMessaged = new Object();
-
- final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
- assertTrue(writeFuture.await().isSuccess());
-
- assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
- assertTrue(embeddedChannel.outboundMessages().isEmpty());
- }
-
- @ParameterizedTest
- @ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5})
- void writeHugeOutboundMessage(final int plaintextLength) throws Throwable {
- final CipherStatePair clientCipherStatePair = doHandshake();
- final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength);
- final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length));
-
- final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
- assertTrue(writePlaintextFuture.isSuccess());
-
- final byte[] decryptedPlaintext = new byte[plaintextLength];
- int plaintextOffset = 0;
- ByteBuf ciphertextFrame;
- while ((ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll()) != null) {
- assertTrue(ciphertextFrame.readableBytes() <= Noise.MAX_PACKET_LEN);
- final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
- ciphertextFrame.release();
- plaintextOffset += clientCipherStatePair.getReceiver()
- .decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
- }
- assertArrayEquals(plaintext, decryptedPlaintext);
- assertEquals(0, plaintextBuffer.refCnt());
-
- }
-
- @Test
- public void writeHugeInboundMessage() throws Throwable {
- doHandshake();
- final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
- embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
- assertThrows(NoiseException.class, embeddedChannel::checkException);
- }
-
- @Test
- public void channelAttributes() throws Throwable {
- doHandshake();
- final NoiseIdentityDeterminedEvent event = getNoiseHandshakeCompleteEvent();
- assertEquals(REMOTE_ADDRESS, event.remoteAddress());
- assertEquals(USER_AGENT, event.userAgent());
- assertEquals(ACCEPT_LANGUAGE, event.acceptLanguage());
- }
-
- protected NoiseTunnelProtos.HandshakeInit.Builder baseHandshakeInit() {
- return NoiseTunnelProtos.HandshakeInit.newBuilder()
- .setUserAgent(USER_AGENT)
- .setAcceptLanguage(ACCEPT_LANGUAGE);
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java
deleted file mode 100644
index 4d1eeda6b..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java
+++ /dev/null
@@ -1,451 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyByte;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-import com.google.protobuf.ByteString;
-import io.grpc.ManagedChannel;
-import io.grpc.ServerBuilder;
-import io.grpc.Status;
-import io.grpc.netty.NettyChannelBuilder;
-import io.grpc.stub.StreamObserver;
-import io.netty.channel.DefaultEventLoopGroup;
-import io.netty.channel.local.LocalAddress;
-import io.netty.channel.local.LocalChannel;
-import io.netty.channel.nio.NioEventLoopGroup;
-import io.netty.handler.codec.haproxy.HAProxyCommand;
-import io.netty.handler.codec.haproxy.HAProxyMessage;
-import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
-import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
-import java.util.Optional;
-import java.util.UUID;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executor;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-import java.util.function.Supplier;
-import org.apache.commons.lang3.RandomStringUtils;
-import org.junit.jupiter.api.AfterAll;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.ValueSource;
-import org.signal.chat.rpc.EchoRequest;
-import org.signal.chat.rpc.EchoResponse;
-import org.signal.chat.rpc.EchoServiceGrpc;
-import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
-import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
-import org.signal.chat.rpc.GetRequestAttributesRequest;
-import org.signal.chat.rpc.RequestAttributesGrpc;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.signal.libsignal.protocol.ecc.ECPublicKey;
-import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
-import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
-import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
-import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
-import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
-import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
-import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
-import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
-import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
-import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
-import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
-import org.whispersystems.textsecuregcm.storage.Device;
-import org.whispersystems.textsecuregcm.util.UUIDUtil;
-
-public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
-
- private static NioEventLoopGroup nioEventLoopGroup;
- private static DefaultEventLoopGroup defaultEventLoopGroup;
- private static ExecutorService delegatedTaskExecutor;
- private static ExecutorService serverCallExecutor;
-
- private GrpcClientConnectionManager grpcClientConnectionManager;
- private ClientPublicKeysManager clientPublicKeysManager;
-
- private ECKeyPair serverKeyPair;
- private ECKeyPair clientKeyPair;
-
- private ManagedLocalGrpcServer authenticatedGrpcServer;
- private ManagedLocalGrpcServer anonymousGrpcServer;
-
- private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
- private static final byte DEVICE_ID = Device.PRIMARY_ID;
-
- public static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
-
- @BeforeAll
- static void setUpBeforeAll() {
- nioEventLoopGroup = new NioEventLoopGroup();
- defaultEventLoopGroup = new DefaultEventLoopGroup();
- delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
- serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
- }
-
- @BeforeEach
- void setUp() throws Exception {
-
- clientKeyPair = ECKeyPair.generate();
- serverKeyPair = ECKeyPair.generate();
-
- grpcClientConnectionManager = new GrpcClientConnectionManager();
-
- clientPublicKeysManager = mock(ClientPublicKeysManager.class);
- when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
- .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
-
- when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
- .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
-
- final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
- final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
-
- authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
- @Override
- protected void configureServer(final ServerBuilder> serverBuilder) {
- serverBuilder
- .executor(serverCallExecutor)
- .addService(new RequestAttributesServiceImpl())
- .addService(new EchoServiceImpl())
- .intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
- .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
- .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
- }
- };
-
- authenticatedGrpcServer.start();
-
- anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
- @Override
- protected void configureServer(final ServerBuilder> serverBuilder) {
- serverBuilder
- .executor(serverCallExecutor)
- .addService(new RequestAttributesServiceImpl())
- .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
- .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
- }
- };
-
- anonymousGrpcServer.start();
- this.start(
- nioEventLoopGroup,
- delegatedTaskExecutor,
- grpcClientConnectionManager,
- clientPublicKeysManager,
- serverKeyPair,
- authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
- RECOGNIZED_PROXY_SECRET);
- }
-
-
- protected abstract void start(
- final NioEventLoopGroup eventLoopGroup,
- final Executor delegatedTaskExecutor,
- final GrpcClientConnectionManager grpcClientConnectionManager,
- final ClientPublicKeysManager clientPublicKeysManager,
- final ECKeyPair serverKeyPair,
- final LocalAddress authenticatedGrpcServerAddress,
- final LocalAddress anonymousGrpcServerAddress,
- final String recognizedProxySecret) throws Exception;
- protected abstract void stop() throws Exception;
- protected abstract NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey);
-
- public void assertClosedWith(final NoiseTunnelClient client, final CloseFrameEvent.CloseReason reason)
- throws ExecutionException, InterruptedException, TimeoutException {
- final CloseFrameEvent result = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
- assertEquals(reason, result.closeReason());
- }
-
- @AfterEach
- void tearDown() throws Exception {
- authenticatedGrpcServer.stop();
- anonymousGrpcServer.stop();
- this.stop();
- }
-
- @AfterAll
- static void tearDownAfterAll() throws InterruptedException {
- nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
- defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
-
- delegatedTaskExecutor.shutdown();
- //noinspection ResultOfMethodCallIgnored
- delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
-
- serverCallExecutor.shutdown();
- //noinspection ResultOfMethodCallIgnored
- serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
- }
-
- @ParameterizedTest
- @ValueSource(booleans = {true, false})
- void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
- try (final NoiseTunnelClient client = authenticated()
- .setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
- .build()) {
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
- .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
-
- assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
- assertEquals(DEVICE_ID, response.getDeviceId());
- } finally {
- channel.shutdown();
- }
- }
- }
-
- @Test
- void connectAuthenticatedBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
-
- // Try to verify the server's public key with something other than the key with which it was signed
- try (final NoiseTunnelClient client = authenticated()
- .setServerPublicKey(ECKeyPair.generate().getPublicKey())
- .build()) {
-
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- //noinspection ResultOfMethodCallIgnored
- GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
- () -> RequestAttributesGrpc.newBlockingStub(channel)
- .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
- } finally {
- channel.shutdown();
- }
- assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
- }
- }
-
- @Test
- void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException, ExecutionException, TimeoutException {
-
- when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
- .thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey())));
-
- try (final NoiseTunnelClient client = authenticated().build()) {
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- //noinspection ResultOfMethodCallIgnored
- GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
- () -> RequestAttributesGrpc.newBlockingStub(channel)
- .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
- } finally {
- channel.shutdown();
- }
- assertEquals(
- NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
- client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
- assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
- }
- }
-
- @Test
- void connectAuthenticatedUnrecognizedDevice() throws InterruptedException, ExecutionException, TimeoutException {
- when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
- .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
-
- try (final NoiseTunnelClient client = authenticated().build()) {
-
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- //noinspection ResultOfMethodCallIgnored
- GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
- () -> RequestAttributesGrpc.newBlockingStub(channel)
- .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
- } finally {
- channel.shutdown();
- }
- assertEquals(
- NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
- client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
- assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
- }
-
- }
-
- @Test
- void clientNormalClosure() throws InterruptedException {
- final NoiseTunnelClient client = anonymous().build();
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
- try {
- final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
- .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
-
- assertTrue(response.getAccountIdentifier().isEmpty());
- assertEquals(0, response.getDeviceId());
- client.close();
-
- // When we gracefully close the tunnel client, we should send an OK close frame
- final CloseFrameEvent closeFrame = client.closeFrameFuture().join();
- assertEquals(CloseFrameEvent.CloseInitiator.CLIENT, closeFrame.closeInitiator());
- assertEquals(CloseFrameEvent.CloseReason.OK, closeFrame.closeReason());
- } finally {
- channel.shutdown();
- }
- }
-
- @Test
- void connectAnonymous() throws InterruptedException {
- try (final NoiseTunnelClient client = anonymous().build()) {
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
- .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
-
- assertTrue(response.getAccountIdentifier().isEmpty());
- assertEquals(0, response.getDeviceId());
- } finally {
- channel.shutdown();
- }
- }
- }
-
- @Test
- void connectAnonymousBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
-
- // Try to verify the server's public key with something other than the key with which it was signed
- try (final NoiseTunnelClient client = anonymous()
- .setServerPublicKey(ECKeyPair.generate().getPublicKey())
- .build()) {
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- //noinspection ResultOfMethodCallIgnored
- GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
- () -> RequestAttributesGrpc.newBlockingStub(channel)
- .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
- } finally {
- channel.shutdown();
- }
- assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
- }
-
- }
-
- protected ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
- return NettyChannelBuilder.forAddress(localAddress)
- .channelType(LocalChannel.class)
- .eventLoopGroup(defaultEventLoopGroup)
- .usePlaintext()
- .build();
- }
-
-
- @Test
- void closeForReauthentication() throws InterruptedException, ExecutionException, TimeoutException {
-
- try (final NoiseTunnelClient client = authenticated().build()) {
-
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
- .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
-
- assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
- assertEquals(DEVICE_ID, response.getDeviceId());
-
- grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
- final CloseFrameEvent closeEvent = client.closeFrameFuture().get(2, TimeUnit.SECONDS);
- assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeEvent.closeReason());
- assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeEvent.closeInitiator());
- } finally {
- channel.shutdown();
- }
- }
- }
-
- @Test
- void waitForCallCompletion() throws InterruptedException, ExecutionException, TimeoutException {
- try (final NoiseTunnelClient client = authenticated().build()) {
-
- final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
-
- try {
- final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
- .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
-
- assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
- assertEquals(DEVICE_ID, response.getDeviceId());
-
- final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
-
- // Start an open-ended server call and leave it in a non-complete state
- final StreamObserver echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
- new StreamObserver<>() {
- @Override
- public void onNext(final EchoResponse echoResponse) {
- responseCountDownLatch.countDown();
- }
-
- @Override
- public void onError(final Throwable throwable) {
- }
-
- @Override
- public void onCompleted() {
- }
- });
-
- // Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
- // the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
- // truly started before requesting connection closure.
- echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
- assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
-
- grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
- try {
- client.closeFrameFuture().get(100, TimeUnit.MILLISECONDS);
- fail("Channel should not close until active requests have finished");
- } catch (TimeoutException e) {
- }
-
- //noinspection ResultOfMethodCallIgnored
- GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
- .echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
-
- // Complete the open-ended server call
- echoRequestStreamObserver.onCompleted();
-
- final CloseFrameEvent closeFrameEvent = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
- assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeFrameEvent.closeInitiator());
- assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeFrameEvent.closeReason());
- } finally {
- channel.shutdown();
- }
- }
- }
-
- protected NoiseTunnelClient.Builder anonymous() {
- return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey());
- }
-
- protected NoiseTunnelClient.Builder authenticated() {
- return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey())
- .setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID);
- }
-
- private static Supplier proxyMessageSupplier(boolean includeProxyMesage) {
- return includeProxyMesage
- ? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
- "10.0.0.1", "10.0.0.2", 12345, 443)
- : null;
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java
deleted file mode 100644
index 654ac5d72..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java
+++ /dev/null
@@ -1,227 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-import com.google.common.net.InetAddresses;
-import io.netty.bootstrap.Bootstrap;
-import io.netty.bootstrap.ServerBootstrap;
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelInitializer;
-import io.netty.channel.DefaultEventLoopGroup;
-import io.netty.channel.EventLoopGroup;
-import io.netty.channel.embedded.EmbeddedChannel;
-import io.netty.channel.local.LocalAddress;
-import io.netty.channel.local.LocalChannel;
-import io.netty.channel.local.LocalServerChannel;
-import java.net.InetAddress;
-import java.util.Collections;
-import java.util.List;
-import java.util.Locale;
-import java.util.Optional;
-import java.util.UUID;
-import org.junit.jupiter.api.AfterAll;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.Arguments;
-import org.junit.jupiter.params.provider.MethodSource;
-import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
-import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
-import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
-import org.whispersystems.textsecuregcm.storage.Device;
-
-class GrpcClientConnectionManagerTest {
-
- private static EventLoopGroup eventLoopGroup;
-
- private LocalChannel localChannel;
- private LocalChannel remoteChannel;
-
- private LocalServerChannel localServerChannel;
-
- private GrpcClientConnectionManager grpcClientConnectionManager;
-
- @BeforeAll
- static void setUpBeforeAll() {
- eventLoopGroup = new DefaultEventLoopGroup();
- }
-
- @BeforeEach
- void setUp() throws InterruptedException {
- eventLoopGroup = new DefaultEventLoopGroup();
-
- grpcClientConnectionManager = new GrpcClientConnectionManager();
-
- // We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial
- // local server to which we can open trivial local connections
- localServerChannel = (LocalServerChannel) new ServerBootstrap()
- .group(eventLoopGroup)
- .channel(LocalServerChannel.class)
- .childHandler(new ChannelInitializer<>() {
- @Override
- protected void initChannel(final Channel channel) {
- }
- })
- .bind(new LocalAddress("test-server"))
- .await()
- .channel();
-
- final Bootstrap clientBootstrap = new Bootstrap()
- .group(eventLoopGroup)
- .channel(LocalChannel.class)
- .handler(new ChannelInitializer<>() {
- @Override
- protected void initChannel(final Channel ch) {
- }
- });
-
- localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
- remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
- }
-
- @AfterEach
- void tearDown() throws InterruptedException {
- localChannel.close().await();
- remoteChannel.close().await();
- localServerChannel.close().await();
- }
-
- @AfterAll
- static void tearDownAfterAll() throws InterruptedException {
- eventLoopGroup.shutdownGracefully().await();
- }
-
- @ParameterizedTest
- @MethodSource
- void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional maybeAuthenticatedDevice) {
- grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
-
- assertEquals(maybeAuthenticatedDevice,
- grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
- }
-
- private static List> getAuthenticatedDevice() {
- return List.of(
- Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)),
- Optional.empty()
- );
- }
-
- @Test
- void getRequestAttributes() {
- grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
-
- assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
-
- final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
- remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
-
- assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
- }
-
- @Test
- void closeConnection() throws InterruptedException, ChannelNotFoundException {
- final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
-
- grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
-
- assertTrue(remoteChannel.isOpen());
-
- assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
- assertEquals(List.of(remoteChannel),
- grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
-
- remoteChannel.close().await();
-
- assertThrows(ChannelNotFoundException.class,
- () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
-
- assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
- }
-
- @ParameterizedTest
- @MethodSource
- void handleHandshakeInitiatedRequestAttributes(final InetAddress preferredRemoteAddress,
- final String userAgentHeader,
- final String acceptLanguageHeader,
- final RequestAttributes expectedRequestAttributes) {
-
- final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
-
- GrpcClientConnectionManager.handleHandshakeInitiated(embeddedChannel,
- preferredRemoteAddress,
- userAgentHeader,
- acceptLanguageHeader);
-
- assertEquals(expectedRequestAttributes,
- embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
- }
-
- private static List handleHandshakeInitiatedRequestAttributes() {
- final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
-
- return List.of(
- Arguments.argumentSet("Null User-Agent and Accept-Language headers",
- preferredRemoteAddress, null, null,
- new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
-
- Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
- preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
- new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
-
- Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
- preferredRemoteAddress, "Not a valid user-agent string", null,
- new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
-
- Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
- preferredRemoteAddress, null, "ja,en;q=0.4",
- new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
-
- Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
- preferredRemoteAddress, null, "This is not a valid language preference list",
- new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
- );
- }
-
- @Test
- void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
- final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
-
- assertThrows(ChannelNotFoundException.class,
- () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
-
- assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
-
- grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
-
- assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
- assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
-
- remoteChannel.close().await();
-
- assertThrows(ChannelNotFoundException.class,
- () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
-
- assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
- }
-
- @Test
- void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
- assertThrows(ChannelNotFoundException.class,
- () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
-
- grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
-
- assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
-
- remoteChannel.close().await();
-
- assertThrows(ChannelNotFoundException.class,
- () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java
deleted file mode 100644
index de26f3c82..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java
+++ /dev/null
@@ -1,62 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.embedded.EmbeddedChannel;
-import io.netty.handler.codec.haproxy.HAProxyCommand;
-import io.netty.handler.codec.haproxy.HAProxyMessage;
-import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
-import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
-import java.util.concurrent.ThreadLocalRandom;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-
-class HAProxyMessageHandlerTest {
-
- private EmbeddedChannel embeddedChannel;
-
- @BeforeEach
- void setUp() {
- embeddedChannel = new EmbeddedChannel(new HAProxyMessageHandler());
- }
-
- @Test
- void handleHAProxyMessage() throws InterruptedException {
- final HAProxyMessage haProxyMessage = new HAProxyMessage(
- HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
- "10.0.0.1", "10.0.0.2", 12345, 443);
-
- final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(haProxyMessage);
- embeddedChannel.flushInbound();
-
- writeFuture.await();
-
- assertTrue(embeddedChannel.inboundMessages().isEmpty());
- assertEquals(0, haProxyMessage.refCnt());
-
- assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
- }
-
- @Test
- void handleNonHAProxyMessage() throws InterruptedException {
- final byte[] bytes = new byte[32];
- ThreadLocalRandom.current().nextBytes(bytes);
-
- final ByteBuf message = Unpooled.wrappedBuffer(bytes);
-
- final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message);
- embeddedChannel.flushInbound();
-
- writeFuture.await();
-
- assertEquals(1, embeddedChannel.inboundMessages().size());
- assertEquals(message, embeddedChannel.inboundMessages().poll());
- assertEquals(1, message.refCnt());
-
- assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java
deleted file mode 100644
index 0a775ccbf..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java
+++ /dev/null
@@ -1,116 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-import com.google.protobuf.ByteString;
-import com.southernstorm.noise.protocol.CipherStatePair;
-import com.southernstorm.noise.protocol.HandshakeState;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufUtil;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.embedded.EmbeddedChannel;
-import java.util.Optional;
-import javax.crypto.BadPaddingException;
-import javax.crypto.ShortBufferException;
-import org.junit.jupiter.api.Test;
-
-class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
-
- @Override
- protected CipherStatePair doHandshake() throws Exception {
- return doHandshake(baseHandshakeInit().build().toByteArray());
- }
-
- private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
-
- final HandshakeState clientHandshakeState =
- new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
-
- clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
- clientHandshakeState.start();
-
- // Send initiator handshake message
-
- // 32 byte key, request payload, 16 byte AEAD tag
- final int initiateHandshakeMessageLength = 32 + requestPayload.length + 16;
- final byte[] initiateHandshakeMessage = new byte[initiateHandshakeMessageLength];
- assertEquals(
- initiateHandshakeMessageLength,
- clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
- final NoiseHandshakeInit message = new NoiseHandshakeInit(
- REMOTE_ADDRESS,
- HandshakePattern.NK,
- Unpooled.wrappedBuffer(initiateHandshakeMessage));
- assertTrue(embeddedChannel.writeOneInbound(message).await().isSuccess());
- assertEquals(0, message.refCnt());
-
- embeddedChannel.runPendingTasks();
-
- // Read responder handshake message
- assertFalse(embeddedChannel.outboundMessages().isEmpty());
- final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
- assertNotNull(responderHandshakeFrame);
- final byte[] responderHandshakeBytes = ByteBufUtil.getBytes(responderHandshakeFrame);
-
- final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponse = NoiseTunnelProtos.HandshakeResponse.newBuilder()
- .setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
- .build();
-
- // ephemeral key, payload, AEAD tag
- assertEquals(32 + expectedHandshakeResponse.getSerializedSize() + 16, responderHandshakeBytes.length);
-
- final byte[] handshakeResponsePlaintext = new byte[expectedHandshakeResponse.getSerializedSize()];
- assertEquals(expectedHandshakeResponse.getSerializedSize(),
- clientHandshakeState.readMessage(
- responderHandshakeBytes, 0, responderHandshakeBytes.length,
- handshakeResponsePlaintext, 0));
-
- assertEquals(expectedHandshakeResponse, NoiseTunnelProtos.HandshakeResponse.parseFrom(handshakeResponsePlaintext));
-
- final byte[] serverPublicKey = new byte[32];
- clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
- assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
-
- return clientHandshakeState.split();
- }
-
- @Test
- void handleCompleteHandshakeWithRequest() throws Exception {
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
-
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- final byte[] handshakePlaintext = baseHandshakeInit()
- .setFastOpenRequest(ByteString.copyFromUtf8("ping")).build()
- .toByteArray();
-
- final CipherStatePair cipherStatePair = doHandshake(handshakePlaintext);
- final byte[] response = readNextPlaintext(cipherStatePair);
- assertArrayEquals(response, "pong".getBytes());
-
- assertEquals(
- new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
- getNoiseHandshakeCompleteEvent());
- }
-
- @Test
- void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
-
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake());
- assertNull(readNextPlaintext(cipherStatePair));
-
- assertEquals(
- new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
- getNoiseHandshakeCompleteEvent());
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java
deleted file mode 100644
index 225800963..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java
+++ /dev/null
@@ -1,338 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertInstanceOf;
-import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.mockito.Mockito.verifyNoInteractions;
-import static org.mockito.Mockito.when;
-
-import com.google.protobuf.ByteString;
-import com.southernstorm.noise.protocol.CipherStatePair;
-import com.southernstorm.noise.protocol.HandshakeState;
-import com.southernstorm.noise.protocol.Noise;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufUtil;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.embedded.EmbeddedChannel;
-import io.netty.util.internal.EmptyArrays;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.security.NoSuchAlgorithmException;
-import java.util.Optional;
-import java.util.UUID;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ThreadLocalRandom;
-import javax.crypto.BadPaddingException;
-import javax.crypto.ShortBufferException;
-import org.junit.jupiter.api.Test;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.signal.libsignal.protocol.ecc.ECPublicKey;
-import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
-import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
-import org.whispersystems.textsecuregcm.storage.Device;
-import org.whispersystems.textsecuregcm.util.TestRandomUtil;
-import org.whispersystems.textsecuregcm.util.UUIDUtil;
-
-class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
-
- private final ECKeyPair clientKeyPair = ECKeyPair.generate();
-
- @Override
- protected CipherStatePair doHandshake() throws Throwable {
- final UUID accountIdentifier = UUID.randomUUID();
- final byte deviceId = randomDeviceId();
- when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
- .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
- return doHandshake(identityPayload(accountIdentifier, deviceId));
- }
-
- @Test
- void handleCompleteHandshakeNoInitialRequest() throws Throwable {
-
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- final UUID accountIdentifier = UUID.randomUUID();
- final byte deviceId = randomDeviceId();
-
- when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
- .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
-
- assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
-
- assertEquals(
- new NoiseIdentityDeterminedEvent(
- Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
- REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
- getNoiseHandshakeCompleteEvent());
- }
-
- @Test
- void handleCompleteHandshakeWithInitialRequest() throws Throwable {
-
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- final UUID accountIdentifier = UUID.randomUUID();
- final byte deviceId = randomDeviceId();
-
- when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
- .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
-
- final byte[] handshakeInit = identifiedHandshakeInit(accountIdentifier, deviceId)
- .setFastOpenRequest(ByteString.copyFromUtf8("ping"))
- .build()
- .toByteArray();
-
- final byte[] response = readNextPlaintext(doHandshake(handshakeInit));
- assertEquals(4, response.length);
- assertEquals("pong", new String(response));
-
- assertEquals(
- new NoiseIdentityDeterminedEvent(
- Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
- REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
- getNoiseHandshakeCompleteEvent());
- }
-
- @Test
- void handleCompleteHandshakeMissingIdentityInformation() {
-
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
-
- verifyNoInteractions(clientPublicKeysManager);
-
- assertNull(getNoiseHandshakeCompleteEvent());
-
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
- "Handshake handler should not remove self from pipeline after failed handshake");
-
- assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
- "Noise stream handler should not be added to pipeline after failed handshake");
- }
-
- @Test
- void handleCompleteHandshakeMalformedIdentityInformation() {
-
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- // no deviceId byte
- byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
- assertThrows(NoiseHandshakeException.class, () -> doHandshake(malformedIdentityPayload));
-
- verifyNoInteractions(clientPublicKeysManager);
-
- assertNull(getNoiseHandshakeCompleteEvent());
-
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
- "Handshake handler should not remove self from pipeline after failed handshake");
-
- assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
- "Noise stream handler should not be added to pipeline after failed handshake");
- }
-
- @Test
- void handleCompleteHandshakeUnrecognizedDevice() throws Throwable {
-
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- final UUID accountIdentifier = UUID.randomUUID();
- final byte deviceId = randomDeviceId();
-
- when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
- .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
-
- doHandshake(
- identityPayload(accountIdentifier, deviceId),
- NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
-
- assertNull(getNoiseHandshakeCompleteEvent());
-
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
- "Handshake handler should not remove self from pipeline after failed handshake");
-
- assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
- "Noise stream handler should not be added to pipeline after failed handshake");
- }
-
- @Test
- void handleCompleteHandshakePublicKeyMismatch() throws Throwable {
-
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- final UUID accountIdentifier = UUID.randomUUID();
- final byte deviceId = randomDeviceId();
-
- when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
- .thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey())));
-
- doHandshake(
- identityPayload(accountIdentifier, deviceId),
- NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
-
- assertNull(getNoiseHandshakeCompleteEvent());
-
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
- "Handshake handler should not remove self from pipeline after failed handshake");
- }
-
- @Test
- void handleInvalidExtraWrites()
- throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
- assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
-
- final UUID accountIdentifier = UUID.randomUUID();
- final byte deviceId = randomDeviceId();
-
- final HandshakeState clientHandshakeState = clientHandshakeState();
-
- final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>();
- when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
-
- final NoiseHandshakeInit handshakeInit = new NoiseHandshakeInit(
- REMOTE_ADDRESS,
- HandshakePattern.IK,
- Unpooled.wrappedBuffer(
- initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
- assertTrue(embeddedChannel.writeOneInbound(handshakeInit).await().isSuccess());
-
- // While waiting for the public key, send another message
- final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
- assertInstanceOf(IllegalArgumentException.class, f.exceptionNow());
-
- findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
- embeddedChannel.runPendingTasks();
- }
-
- @Test
- public void handleOversizeHandshakeMessage() {
- final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
- ByteBuffer.wrap(big)
- .put(UUIDUtil.toBytes(UUID.randomUUID()))
- .put((byte) 0x01);
- assertThrows(NoiseHandshakeException.class, () -> doHandshake(big));
- }
-
- @Test
- public void handleKeyLookupError() {
- final UUID accountIdentifier = UUID.randomUUID();
- final byte deviceId = randomDeviceId();
- when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
- .thenReturn(CompletableFuture.failedFuture(new IOException()));
- assertThrows(IOException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
- }
-
- private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
- final HandshakeState clientHandshakeState =
- new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
-
- clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
- clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
- clientHandshakeState.start();
- return clientHandshakeState;
- }
-
- private byte[] initiatorHandshakeMessage(final HandshakeState clientHandshakeState, final byte[] payload)
- throws ShortBufferException {
- // Ephemeral key, encrypted static key, AEAD tag, encrypted payload, AEAD tag
- final byte[] initiatorMessageBytes = new byte[32 + 32 + 16 + payload.length + 16];
- int written = clientHandshakeState.writeMessage(initiatorMessageBytes, 0, payload, 0, payload.length);
- assertEquals(written, initiatorMessageBytes.length);
- return initiatorMessageBytes;
- }
-
- private byte[] readHandshakeResponse(final HandshakeState clientHandshakeState, final byte[] message)
- throws ShortBufferException, BadPaddingException {
-
- // 32 byte ephemeral server key, 16 byte AEAD tag for encrypted payload
- final int expectedResponsePayloadLength = message.length - 32 - 16;
- final byte[] responsePayload = new byte[expectedResponsePayloadLength];
- final int responsePayloadLength = clientHandshakeState.readMessage(message, 0, message.length, responsePayload, 0);
- assertEquals(expectedResponsePayloadLength, responsePayloadLength);
- return responsePayload;
- }
-
- private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
- return doHandshake(payload, NoiseTunnelProtos.HandshakeResponse.Code.OK);
- }
-
- private CipherStatePair doHandshake(final byte[] payload, final NoiseTunnelProtos.HandshakeResponse.Code expectedStatus) throws Throwable {
- final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
-
- final HandshakeState clientHandshakeState = clientHandshakeState();
- final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
-
- final NoiseHandshakeInit initMessage = new NoiseHandshakeInit(
- REMOTE_ADDRESS,
- HandshakePattern.IK,
- Unpooled.wrappedBuffer(initiatorMessage));
- final ChannelFuture await = embeddedChannel.writeOneInbound(initMessage).await();
- assertEquals(0, initMessage.refCnt());
- if (!await.isSuccess() && expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
- throw await.cause();
- }
-
- // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
- // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
- // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
- // and issue a "handshake complete" event.
- embeddedChannel.runPendingTasks();
-
- // rethrow if running the task caused an error, and the caller isn't expecting an error
- if (expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
- embeddedChannel.checkException();
- }
-
- assertFalse(embeddedChannel.outboundMessages().isEmpty());
-
- final ByteBuf handshakeResponseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
- assertNotNull(handshakeResponseFrame);
- final byte[] handshakeResponseCiphertextBytes = ByteBufUtil.getBytes(handshakeResponseFrame);
-
- final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponsePlaintext = NoiseTunnelProtos.HandshakeResponse.newBuilder()
- .setCode(expectedStatus)
- .build();
-
- final byte[] actualHandshakeResponsePlaintext =
- readHandshakeResponse(clientHandshakeState, handshakeResponseCiphertextBytes);
-
- assertEquals(
- expectedHandshakeResponsePlaintext,
- NoiseTunnelProtos.HandshakeResponse.parseFrom(actualHandshakeResponsePlaintext));
-
- final byte[] serverPublicKey = new byte[32];
- clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
- assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
-
- return clientHandshakeState.split();
- }
-
- private NoiseTunnelProtos.HandshakeInit.Builder identifiedHandshakeInit(final UUID accountIdentifier, final byte deviceId) {
- return baseHandshakeInit()
- .setAci(UUIDUtil.toByteString(accountIdentifier))
- .setDeviceId(deviceId);
- }
-
- private byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
- return identifiedHandshakeInit(accountIdentifier, deviceId)
- .build()
- .toByteArray();
- }
-
- private static byte randomDeviceId() {
- return (byte) ThreadLocalRandom.current().nextInt(1, Device.MAXIMUM_DEVICE_ID + 1);
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java
deleted file mode 100644
index 13d26813c..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Copyright 2024 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatNoException;
-
-import com.southernstorm.noise.protocol.HandshakeState;
-import io.netty.buffer.ByteBuf;
-import java.nio.charset.StandardCharsets;
-import javax.crypto.ShortBufferException;
-import io.netty.buffer.ByteBufUtil;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.EnumSource;
-import org.signal.libsignal.protocol.ecc.ECKeyPair;
-import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientHandshakeHelper;
-
-
-public class NoiseHandshakeHelperTest {
-
- @ParameterizedTest
- @EnumSource(HandshakePattern.class)
- void testWithPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
- doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), "pong".getBytes(StandardCharsets.UTF_8));
- }
-
- @ParameterizedTest
- @EnumSource(HandshakePattern.class)
- void testWithRequestPayload(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
- doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), new byte[0]);
- }
-
- @ParameterizedTest
- @EnumSource(HandshakePattern.class)
- void testWithoutPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
- doHandshake(pattern, new byte[0], new byte[0]);
- }
-
- void doHandshake(final HandshakePattern pattern, final byte[] requestPayload, final byte[] responsePayload) throws ShortBufferException, NoiseHandshakeException {
- final ECKeyPair serverKeyPair = ECKeyPair.generate();
- final ECKeyPair clientKeyPair = ECKeyPair.generate();
-
- NoiseHandshakeHelper serverHelper = new NoiseHandshakeHelper(pattern, serverKeyPair);
- NoiseClientHandshakeHelper clientHelper = switch (pattern) {
- case IK -> NoiseClientHandshakeHelper.IK(serverKeyPair.getPublicKey(), clientKeyPair);
- case NK -> NoiseClientHandshakeHelper.NK(serverKeyPair.getPublicKey());
- };
-
- final byte[] initiate = clientHelper.write(requestPayload);
- final ByteBuf actualRequestPayload = serverHelper.read(initiate);
- assertThat(ByteBufUtil.getBytes(actualRequestPayload)).isEqualTo(requestPayload);
-
- assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.WRITE_MESSAGE);
-
- final byte[] respond = serverHelper.write(responsePayload);
- byte[] actualResponsePayload = clientHelper.read(respond);
- assertThat(actualResponsePayload).isEqualTo(responsePayload);
-
- assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.SPLIT);
- assertThatNoException().isThrownBy(() -> serverHelper.getHandshakeState().split());
- assertThatNoException().isThrownBy(() -> clientHelper.split());
- }
-
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java
deleted file mode 100644
index d7d3751ac..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java
+++ /dev/null
@@ -1,108 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net;
-
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertInstanceOf;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufUtil;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.embedded.EmbeddedChannel;
-import io.netty.handler.codec.haproxy.HAProxyMessage;
-import java.util.HexFormat;
-import java.util.concurrent.ThreadLocalRandom;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-
-class ProxyProtocolDetectionHandlerTest {
-
- private EmbeddedChannel embeddedChannel;
-
- private static final byte[] PROXY_V2_MESSAGE_BYTES =
- HexFormat.of().parseHex("0d0a0d0a000d0a515549540a2111000c0a0000010a000002303901bb");
-
- @BeforeEach
- void setUp() {
- embeddedChannel = new EmbeddedChannel(new ProxyProtocolDetectionHandler());
- }
-
- @Test
- void singlePacketProxyMessage() throws InterruptedException {
- final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES));
- embeddedChannel.flushInbound();
-
- writeFuture.await();
-
- assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
- assertEquals(1, embeddedChannel.inboundMessages().size());
- assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
- }
-
- @Test
- void multiPacketProxyMessage() throws InterruptedException {
- final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
- Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, 0,
- ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
-
- final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
- Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
- PROXY_V2_MESSAGE_BYTES.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
-
- embeddedChannel.flushInbound();
-
- firstWriteFuture.await();
- secondWriteFuture.await();
-
- assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
- assertEquals(1, embeddedChannel.inboundMessages().size());
- assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
- }
-
- @Test
- void singlePacketNonProxyMessage() throws InterruptedException {
- final byte[] nonProxyProtocolMessage = new byte[32];
- ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
-
- final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(nonProxyProtocolMessage));
- embeddedChannel.flushInbound();
-
- writeFuture.await();
-
- assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
- assertEquals(1, embeddedChannel.inboundMessages().size());
-
- final Object inboundMessage = embeddedChannel.inboundMessages().poll();
-
- assertInstanceOf(ByteBuf.class, inboundMessage);
- assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
- }
-
- @Test
- void multiPacketNonProxyMessage() throws InterruptedException {
- final byte[] nonProxyProtocolMessage = new byte[32];
- ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
-
- final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
- Unpooled.wrappedBuffer(nonProxyProtocolMessage, 0,
- ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
-
- final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
- Unpooled.wrappedBuffer(nonProxyProtocolMessage, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
- nonProxyProtocolMessage.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
-
- embeddedChannel.flushInbound();
-
- firstWriteFuture.await();
- secondWriteFuture.await();
-
- assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
- assertEquals(1, embeddedChannel.inboundMessages().size());
-
- final Object inboundMessage = embeddedChannel.inboundMessages().poll();
-
- assertInstanceOf(ByteBuf.class, inboundMessage);
- assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java
deleted file mode 100644
index 0b6e2c0ac..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java
+++ /dev/null
@@ -1,16 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.client;
-
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class ClientErrorHandler extends ChannelInboundHandlerAdapter {
- private static final Logger log = LoggerFactory.getLogger(ClientErrorHandler.class);
-
- @Override
- public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
- log.error("Caught inbound error in client; closing connection", cause);
- context.channel().close();
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java
deleted file mode 100644
index 13c8b9924..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright 2025 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.grpc.net.client;
-
-import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
-import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
-
-public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
-
- public enum CloseReason {
- OK,
- SERVER_CLOSED,
- NOISE_ERROR,
- NOISE_HANDSHAKE_ERROR,
- INTERNAL_SERVER_ERROR,
- UNKNOWN
- }
-
- public enum CloseInitiator {
- SERVER,
- CLIENT
- }
-
- public static CloseFrameEvent fromWebsocketCloseFrame(
- CloseWebSocketFrame closeWebSocketFrame,
- CloseInitiator closeInitiator) {
- final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
- case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
- case 4002 -> CloseReason.NOISE_ERROR;
- case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
- case 1012 -> CloseReason.SERVER_CLOSED;
- case 1000 -> CloseReason.OK;
- default -> CloseReason.UNKNOWN;
- };
- return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
- }
-
- public static CloseFrameEvent fromNoiseDirectCloseFrame(
- NoiseDirectProtos.CloseReason noiseDirectCloseReason,
- CloseInitiator closeInitiator) {
- final CloseReason code = switch (noiseDirectCloseReason.getCode()) {
- case OK -> CloseReason.OK;
- case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
- case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
- case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
- case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
- case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
- };
- return new CloseFrameEvent(code, closeInitiator, noiseDirectCloseReason.getMessage());
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java
deleted file mode 100644
index c735902b6..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java
+++ /dev/null
@@ -1,120 +0,0 @@
-package org.whispersystems.textsecuregcm.grpc.net.client;
-
-import io.netty.bootstrap.Bootstrap;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandler;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.channel.ChannelInitializer;
-import io.netty.channel.socket.SocketChannel;
-import io.netty.channel.socket.nio.NioSocketChannel;
-import io.netty.util.ReferenceCountUtil;
-import java.net.SocketAddress;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
-import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
-
-/**
- * Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
- * gRPC server.
- *
- * This handler waits until the first gRPC client message is ready and then establishes a connection with the remote
- * gRPC server. It expects the provided remoteHandlerStack to emit a {@link ReadyForNoiseHandshakeEvent} when the remote
- * connection is ready for its first inbound payload, and to emit a {@link NoiseClientHandshakeCompleteEvent} when the
- * handshake is finished.
- */
-class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
-
- private final List remoteHandlerStack;
- private final NoiseTunnelProtos.HandshakeInit handshakeInit;
-
- private final SocketAddress remoteServerAddress;
- // If provided, will be sent with the payload in the noise handshake
-
- private final List