Add direct grpc server

This commit is contained in:
ravi-signal
2025-10-06 15:22:36 -05:00
committed by GitHub
parent 9751569dc7
commit a2f2fc93b0
87 changed files with 546 additions and 6469 deletions

View File

@@ -20,8 +20,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -30,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class DisconnectionRequestManagerTest {
private GrpcClientConnectionManager grpcClientConnectionManager;
private DisconnectionRequestManager disconnectionRequestManager;
@RegisterExtension
@@ -38,10 +35,8 @@ class DisconnectionRequestManagerTest {
@BeforeEach
void setUp() {
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
disconnectionRequestManager = new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(),
grpcClientConnectionManager,
Runnable::run,
mock(ScheduledExecutorService.class));
@@ -103,16 +98,8 @@ class DisconnectionRequestManagerTest {
verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest();
verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest();
verify(grpcClientConnectionManager, timeout(1_000))
.closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId));
verify(grpcClientConnectionManager, timeout(1_000))
.closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId));
disconnectionRequestManager.requestDisconnection(otherAccountIdentifier, List.of(otherDeviceId));
verify(grpcClientConnectionManager, timeout(1_000))
.closeConnection(new AuthenticatedDevice(otherAccountIdentifier, otherDeviceId));
}
@Test
@@ -141,11 +128,5 @@ class DisconnectionRequestManagerTest {
verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest();
verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest();
verify(grpcClientConnectionManager, timeout(1_000))
.closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId));
verify(grpcClientConnectionManager, timeout(1_000))
.closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId));
}
}

View File

@@ -1,77 +0,0 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.mockito.Mockito.mock;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
abstract class AbstractAuthenticationInterceptorTest {
private static DefaultEventLoopGroup eventLoopGroup;
private GrpcClientConnectionManager grpcClientConnectionManager;
private Server server;
private ManagedChannel managedChannel;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server");
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(getInterceptor())
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
protected abstract AbstractAuthenticationInterceptor getInterceptor();
protected GrpcClientConnectionManager getClientConnectionManager() {
return grpcClientConnectionManager;
}
protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() {
return RequestAttributesGrpc.newBlockingStub(managedChannel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
}
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.CallCredentials;
import io.grpc.Metadata;
import io.grpc.Status;
import java.util.concurrent.Executor;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
public class BasicAuthCallCredentials extends CallCredentials {
private final String username;
private final String password;
public BasicAuthCallCredentials(String username, String password) {
this.username = username;
this.password = password;
}
@Override
public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor,
final MetadataApplier applier) {
try {
Metadata headers = new Metadata();
headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER),
HeaderUtils.basicAuthHeader(username, password));
applier.apply(headers);
} catch (Exception e) {
applier.fail(Status.UNAUTHENTICATED.withCause(e));
}
}
}

View File

@@ -1,44 +1,64 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
class ProhibitAuthenticationInterceptorTest {
private Server server;
private ManagedChannel channel;
@Override
protected AbstractAuthenticationInterceptor getInterceptor() {
return new ProhibitAuthenticationInterceptor(getClientConnectionManager());
@BeforeEach
void setUp() throws Exception {
server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest")
.directExecutor()
.intercept(new ProhibitAuthenticationInterceptor())
.addService(new EchoServiceImpl())
.build()
.start();
channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest")
.directExecutor()
.build();
}
@AfterEach
void tearDown() throws Exception {
channel.shutdownNow();
server.shutdownNow();
channel.awaitTermination(5, TimeUnit.SECONDS);
server.awaitTermination(5, TimeUnit.SECONDS);
}
@Test
void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
void hasAuth() {
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc
.newBlockingStub(channel)
.withCallCredentials(new BasicAuthCallCredentials("test", "password"));
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
() -> client.echo(EchoRequest.getDefaultInstance()));
assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
}
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
@Test
void noAuth() {
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
assertDoesNotThrow(() -> client.echo(EchoRequest.getDefaultInstance()));
}
}

View File

@@ -1,44 +1,102 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.basic.BasicCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import java.time.Instant;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
class RequireAuthenticationInterceptorTest {
private Server server;
private ManagedChannel channel;
private AccountAuthenticator authenticator;
@Override
protected AbstractAuthenticationInterceptor getInterceptor() {
return new RequireAuthenticationInterceptor(getClientConnectionManager());
@BeforeEach
void setUp() throws Exception {
authenticator = mock(AccountAuthenticator.class);
server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest")
.directExecutor()
.intercept(new RequireAuthenticationInterceptor(authenticator))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest")
.directExecutor()
.build();
}
@AfterEach
void tearDown() throws Exception {
channel.shutdownNow();
server.shutdownNow();
channel.awaitTermination(5, TimeUnit.SECONDS);
server.awaitTermination(5, TimeUnit.SECONDS);
}
@Test
void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
void hasAuth() {
final UUID aci = UUID.randomUUID();
final byte deviceId = 2;
when(authenticator.authenticate(eq(new BasicCredentials("test", "password"))))
.thenReturn(Optional.of(
new org.whispersystems.textsecuregcm.auth.AuthenticatedDevice(aci, deviceId, Instant.now())));
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
.newBlockingStub(channel)
.withCallCredentials(new BasicAuthCallCredentials("test", "password"));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
final GetAuthenticatedDeviceResponse authenticatedDevice = client.getAuthenticatedDevice(
GetAuthenticatedDeviceRequest.getDefaultInstance());
assertEquals(authenticatedDevice.getDeviceId(), deviceId);
assertEquals(UUIDUtil.fromByteString(authenticatedDevice.getAccountIdentifier()), aci);
}
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
@Test
void badCredentials() {
when(authenticator.authenticate(any())).thenReturn(Optional.empty());
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
.newBlockingStub(channel)
.withCallCredentials(new BasicAuthCallCredentials("test", "password"));
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
() -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance()));
Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
}
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
@Test
void missingCredentials() {
when(authenticator.authenticate(any())).thenReturn(Optional.empty());
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc.newBlockingStub(channel);
final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
() -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance()));
Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
}
}

View File

@@ -1,88 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.Status;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
class ChannelShutdownInterceptorTest {
private GrpcClientConnectionManager grpcClientConnectionManager;
private ChannelShutdownInterceptor channelShutdownInterceptor;
private ServerCallHandler<String, String> nextCallHandler;
private static final Metadata HEADERS = new Metadata();
@BeforeEach
void setUp() {
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
//noinspection unchecked
nextCallHandler = mock(ServerCallHandler.class);
//noinspection unchecked
when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
}
@Test
void interceptCallComplete() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onComplete();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallCancelled() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
final ServerCall.Listener<String> serverCallListener =
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
serverCallListener.onCancel();
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
verify(serverCall, never()).close(any(), any());
}
@Test
void interceptCallChannelClosing() {
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
verify(serverCall).close(eq(Status.UNAVAILABLE), any());
}
}

View File

@@ -0,0 +1,131 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.MetadataUtils;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
public class RequestAttributesInterceptorTest {
private static String USER_AGENT = "Signal-Android/4.53.7 (Android 8.1; libsignal)";
private Server server;
private AtomicBoolean removeUserAgent;
@BeforeEach
void setUp() throws Exception {
removeUserAgent = new AtomicBoolean(false);
server = NettyServerBuilder.forPort(0)
.directExecutor()
.intercept(new RequestAttributesInterceptor())
// the grpc client always inserts a user-agent if we don't set one, so to test missing UAs we remove the header
// on the server-side
.intercept(new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers, final ServerCallHandler<ReqT, RespT> next) {
if (removeUserAgent.get()) {
headers.removeAll(Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER));
}
return next.startCall(call, headers);
}
})
.addService(new RequestAttributesServiceImpl())
.build()
.start();
}
@AfterEach
void tearDown() throws Exception {
server.shutdownNow();
server.awaitTermination(1, TimeUnit.SECONDS);
}
private static List<Arguments> handleInvalidAcceptLanguage() {
return List.of(
Arguments.argumentSet("Null Accept-Language header", Optional.empty()),
Arguments.argumentSet("Empty Accept-Language header", Optional.of("")),
Arguments.argumentSet("Invalid Accept-Language header", Optional.of("This is not a valid language preference list")));
}
@ParameterizedTest
@MethodSource
void handleInvalidAcceptLanguage(Optional<String> acceptLanguageHeader) throws Exception {
final Metadata metadata = new Metadata();
acceptLanguageHeader.ifPresent(h -> metadata
.put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), h));
final GetRequestAttributesResponse response = getRequestAttributes(metadata);
assertEquals(response.getAcceptableLanguagesCount(), 0);
}
@Test
void handleMissingUserAgent() throws InterruptedException {
removeUserAgent.set(true);
final GetRequestAttributesResponse response = getRequestAttributes(new Metadata());
assertEquals("", response.getUserAgent());
}
@Test
void allAttributes() throws InterruptedException {
final Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), "ja,en;q=0.4");
metadata.put(Metadata.Key.of("x-forwarded-for", Metadata.ASCII_STRING_MARSHALLER), "127.0.0.3");
final GetRequestAttributesResponse response = getRequestAttributes(metadata);
assertTrue(response.getUserAgent().contains(USER_AGENT));
assertEquals("127.0.0.3", response.getRemoteAddress());
assertEquals(2, response.getAcceptableLanguagesCount());
assertEquals("ja", response.getAcceptableLanguages(0));
assertEquals("en;q=0.4", response.getAcceptableLanguages(1));
}
@Test
void useSocketAddrIfHeaderMissing() throws InterruptedException {
final GetRequestAttributesResponse response = getRequestAttributes(new Metadata());
assertEquals("127.0.0.1", response.getRemoteAddress());
}
private GetRequestAttributesResponse getRequestAttributes(Metadata metadata)
throws InterruptedException {
final ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", server.getPort())
.directExecutor()
.usePlaintext()
.userAgent(USER_AGENT)
.build();
try {
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
.newBlockingStub(channel)
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
return client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance());
} finally {
channel.shutdownNow();
channel.awaitTermination(1, TimeUnit.SECONDS);
}
}
}

View File

@@ -1,21 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.util.ResourceLeakDetector;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
public abstract class AbstractLeakDetectionTest {
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
@BeforeAll
static void setLeakDetectionLevel() {
originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel();
ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
}
@AfterAll
static void restoreLeakDetectionLevel() {
ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel);
}
}

View File

@@ -1,325 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import javax.annotation.Nullable;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
protected ECKeyPair serverKeyPair;
protected ClientPublicKeysManager clientPublicKeysManager;
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
private EmbeddedChannel embeddedChannel;
static final String USER_AGENT = "Test/User-Agent";
static final String ACCEPT_LANGUAGE = "test-lang";
static final InetAddress REMOTE_ADDRESS;
static {
try {
REMOTE_ADDRESS = InetAddress.getByAddress(new byte[]{0,1,2,3});
} catch (UnknownHostException e) {
throw new RuntimeException(e);
}
}
private static class PongHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
try {
if (msg instanceof ByteBuf bb) {
if (new String(ByteBufUtil.getBytes(bb)).equals("ping")) {
ctx.writeAndFlush(Unpooled.wrappedBuffer("pong".getBytes()))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
} else {
throw new IllegalArgumentException("Unexpected message: " + new String(ByteBufUtil.getBytes(bb)));
}
} else {
throw new IllegalArgumentException("Unexpected message type: " + msg);
}
} finally {
ReferenceCountUtil.release(msg);
}
}
}
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
@Nullable
private NoiseIdentityDeterminedEvent handshakeCompleteEvent = null;
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
handshakeCompleteEvent = noiseIdentityDeterminedEvent;
context.pipeline().addAfter(context.name(), null, new PongHandler());
context.pipeline().remove(NoiseHandshakeCompleteHandler.class);
} else {
context.fireUserEventTriggered(event);
}
}
@Nullable
public NoiseIdentityDeterminedEvent getHandshakeCompleteEvent() {
return handshakeCompleteEvent;
}
}
@BeforeEach
void setUp() {
serverKeyPair = ECKeyPair.generate();
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
embeddedChannel = new EmbeddedChannel(
new NoiseHandshakeHandler(clientPublicKeysManager, serverKeyPair),
noiseHandshakeCompleteHandler);
}
@AfterEach
void tearDown() {
embeddedChannel.close();
}
protected EmbeddedChannel getEmbeddedChannel() {
return embeddedChannel;
}
@Nullable
protected NoiseIdentityDeterminedEvent getNoiseHandshakeCompleteEvent() {
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
}
protected abstract CipherStatePair doHandshake() throws Throwable;
/**
* Read a message from the embedded channel and deserialize it with the provided client cipher state. If there are no
* waiting messages in the channel, return null.
*/
byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
final ByteBuf responseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
if (responseFrame == null) {
return null;
}
final byte[] plaintext = new byte[responseFrame.readableBytes() - 16];
final int read = clientCipherPair.getReceiver().decryptWithAd(null,
ByteBufUtil.getBytes(responseFrame), 0,
plaintext, 0,
responseFrame.readableBytes());
assertEquals(read, plaintext.length);
return plaintext;
}
@Test
void handleInvalidInitialMessage() throws InterruptedException {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new NoiseHandshakeInit(REMOTE_ADDRESS, HandshakePattern.IK, content)).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
assertEquals(0, content.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
final ByteBuf[] frames = new ByteBuf[7];
for (int i = 0; i < frames.length; i++) {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
frames[i] = Unpooled.wrappedBuffer(contentBytes);
embeddedChannel.writeOneInbound(frames[i]).await();
}
for (final ByteBuf frame : frames) {
assertEquals(0, frame.refCnt());
}
assertNull(getNoiseHandshakeCompleteEvent());
}
@Test
void handleNonByteBufBinaryFrame() throws Throwable {
final byte[] contentBytes = new byte[17];
ThreadLocalRandom.current().nextBytes(contentBytes);
final BinaryWebSocketFrame message = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
assertFalse(writeFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
assertEquals(0, message.refCnt());
assertNull(getNoiseHandshakeCompleteEvent());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void channelRead() throws Throwable {
final CipherStatePair clientCipherStatePair = doHandshake();
final byte[] plaintext = "ping".getBytes(StandardCharsets.UTF_8);
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(ciphertext);
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
assertEquals(0, ciphertextFrame.refCnt());
final byte[] response = readNextPlaintext(clientCipherStatePair);
assertArrayEquals("pong".getBytes(StandardCharsets.UTF_8), response);
}
@Test
void channelReadBadCiphertext() throws Throwable {
doHandshake();
final byte[] bogusCiphertext = new byte[32];
io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(bogusCiphertext);
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
assertEquals(0, ciphertextFrame.refCnt());
assertFalse(readCiphertextFuture.isSuccess());
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void channelReadUnexpectedMessageType() throws Throwable {
doHandshake();
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
assertFalse(readFuture.isSuccess());
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
}
@Test
void write() throws Throwable {
final CipherStatePair clientCipherStatePair = doHandshake();
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
assertTrue(writePlaintextFuture.await().isSuccess());
assertEquals(0, plaintextBuffer.refCnt());
final ByteBuf ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
assertNotNull(ciphertextFrame);
assertTrue(embeddedChannel.outboundMessages().isEmpty());
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
ciphertextFrame.release();
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
assertArrayEquals(plaintext, decryptedPlaintext);
}
@Test
void writeUnexpectedMessageType() throws Throwable {
doHandshake();
final Object unexpectedMessaged = new Object();
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
assertTrue(writeFuture.await().isSuccess());
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
assertTrue(embeddedChannel.outboundMessages().isEmpty());
}
@ParameterizedTest
@ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5})
void writeHugeOutboundMessage(final int plaintextLength) throws Throwable {
final CipherStatePair clientCipherStatePair = doHandshake();
final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength);
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length));
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
assertTrue(writePlaintextFuture.isSuccess());
final byte[] decryptedPlaintext = new byte[plaintextLength];
int plaintextOffset = 0;
ByteBuf ciphertextFrame;
while ((ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll()) != null) {
assertTrue(ciphertextFrame.readableBytes() <= Noise.MAX_PACKET_LEN);
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
ciphertextFrame.release();
plaintextOffset += clientCipherStatePair.getReceiver()
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
}
assertArrayEquals(plaintext, decryptedPlaintext);
assertEquals(0, plaintextBuffer.refCnt());
}
@Test
public void writeHugeInboundMessage() throws Throwable {
doHandshake();
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
assertThrows(NoiseException.class, embeddedChannel::checkException);
}
@Test
public void channelAttributes() throws Throwable {
doHandshake();
final NoiseIdentityDeterminedEvent event = getNoiseHandshakeCompleteEvent();
assertEquals(REMOTE_ADDRESS, event.remoteAddress());
assertEquals(USER_AGENT, event.userAgent());
assertEquals(ACCEPT_LANGUAGE, event.acceptLanguage());
}
protected NoiseTunnelProtos.HandshakeInit.Builder baseHandshakeInit() {
return NoiseTunnelProtos.HandshakeInit.newBuilder()
.setUserAgent(USER_AGENT)
.setAcceptLanguage(ACCEPT_LANGUAGE);
}
}

View File

@@ -1,451 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
private static NioEventLoopGroup nioEventLoopGroup;
private static DefaultEventLoopGroup defaultEventLoopGroup;
private static ExecutorService delegatedTaskExecutor;
private static ExecutorService serverCallExecutor;
private GrpcClientConnectionManager grpcClientConnectionManager;
private ClientPublicKeysManager clientPublicKeysManager;
private ECKeyPair serverKeyPair;
private ECKeyPair clientKeyPair;
private ManagedLocalGrpcServer authenticatedGrpcServer;
private ManagedLocalGrpcServer anonymousGrpcServer;
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
public static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
@BeforeAll
static void setUpBeforeAll() {
nioEventLoopGroup = new NioEventLoopGroup();
defaultEventLoopGroup = new DefaultEventLoopGroup();
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
}
@BeforeEach
void setUp() throws Exception {
clientKeyPair = ECKeyPair.generate();
serverKeyPair = ECKeyPair.generate();
grpcClientConnectionManager = new GrpcClientConnectionManager();
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.addService(new EchoServiceImpl())
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
}
};
authenticatedGrpcServer.start();
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder
.executor(serverCallExecutor)
.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
}
};
anonymousGrpcServer.start();
this.start(
nioEventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
}
protected abstract void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception;
protected abstract void stop() throws Exception;
protected abstract NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey);
public void assertClosedWith(final NoiseTunnelClient client, final CloseFrameEvent.CloseReason reason)
throws ExecutionException, InterruptedException, TimeoutException {
final CloseFrameEvent result = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
assertEquals(reason, result.closeReason());
}
@AfterEach
void tearDown() throws Exception {
authenticatedGrpcServer.stop();
anonymousGrpcServer.stop();
this.stop();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
delegatedTaskExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
serverCallExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
try (final NoiseTunnelClient client = authenticated()
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAuthenticatedBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseTunnelClient client = authenticated()
.setServerPublicKey(ECKeyPair.generate().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException, ExecutionException, TimeoutException {
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey())));
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertEquals(
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException, ExecutionException, TimeoutException {
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertEquals(
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void clientNormalClosure() throws InterruptedException {
final NoiseTunnelClient client = anonymous().build();
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
client.close();
// When we gracefully close the tunnel client, we should send an OK close frame
final CloseFrameEvent closeFrame = client.closeFrameFuture().join();
assertEquals(CloseFrameEvent.CloseInitiator.CLIENT, closeFrame.closeInitiator());
assertEquals(CloseFrameEvent.CloseReason.OK, closeFrame.closeReason());
} finally {
channel.shutdown();
}
}
@Test
void connectAnonymous() throws InterruptedException {
try (final NoiseTunnelClient client = anonymous().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAnonymousBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
// Try to verify the server's public key with something other than the key with which it was signed
try (final NoiseTunnelClient client = anonymous()
.setServerPublicKey(ECKeyPair.generate().getPublicKey())
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
protected ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
return NettyChannelBuilder.forAddress(localAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(defaultEventLoopGroup)
.usePlaintext()
.build();
}
@Test
void closeForReauthentication() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
final CloseFrameEvent closeEvent = client.closeFrameFuture().get(2, TimeUnit.SECONDS);
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeEvent.closeReason());
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeEvent.closeInitiator());
} finally {
channel.shutdown();
}
}
}
@Test
void waitForCallCompletion() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = authenticated().build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
// Start an open-ended server call and leave it in a non-complete state
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
new StreamObserver<>() {
@Override
public void onNext(final EchoResponse echoResponse) {
responseCountDownLatch.countDown();
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onCompleted() {
}
});
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
// truly started before requesting connection closure.
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
try {
client.closeFrameFuture().get(100, TimeUnit.MILLISECONDS);
fail("Channel should not close until active requests have finished");
} catch (TimeoutException e) {
}
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
// Complete the open-ended server call
echoRequestStreamObserver.onCompleted();
final CloseFrameEvent closeFrameEvent = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeFrameEvent.closeInitiator());
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeFrameEvent.closeReason());
} finally {
channel.shutdown();
}
}
}
protected NoiseTunnelClient.Builder anonymous() {
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey());
}
protected NoiseTunnelClient.Builder authenticated() {
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey())
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID);
}
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
return includeProxyMesage
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443)
: null;
}
}

View File

@@ -1,227 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.net.InetAddresses;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.Device;
class GrpcClientConnectionManagerTest {
private static EventLoopGroup eventLoopGroup;
private LocalChannel localChannel;
private LocalChannel remoteChannel;
private LocalServerChannel localServerChannel;
private GrpcClientConnectionManager grpcClientConnectionManager;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws InterruptedException {
eventLoopGroup = new DefaultEventLoopGroup();
grpcClientConnectionManager = new GrpcClientConnectionManager();
// We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial
// local server to which we can open trivial local connections
localServerChannel = (LocalServerChannel) new ServerBootstrap()
.group(eventLoopGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel channel) {
}
})
.bind(new LocalAddress("test-server"))
.await()
.channel();
final Bootstrap clientBootstrap = new Bootstrap()
.group(eventLoopGroup)
.channel(LocalChannel.class)
.handler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel ch) {
}
});
localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
}
@AfterEach
void tearDown() throws InterruptedException {
localChannel.close().await();
remoteChannel.close().await();
localServerChannel.close().await();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
}
@ParameterizedTest
@MethodSource
void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice,
grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
}
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
return List.of(
Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)),
Optional.empty()
);
}
@Test
void getRequestAttributes() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
}
@Test
void closeConnection() throws InterruptedException, ChannelNotFoundException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel),
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@ParameterizedTest
@MethodSource
void handleHandshakeInitiatedRequestAttributes(final InetAddress preferredRemoteAddress,
final String userAgentHeader,
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleHandshakeInitiated(embeddedChannel,
preferredRemoteAddress,
userAgentHeader,
acceptLanguageHeader);
assertEquals(expectedRequestAttributes,
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
}
private static List<Arguments> handleHandshakeInitiatedRequestAttributes() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
return List.of(
Arguments.argumentSet("Null User-Agent and Accept-Language headers",
preferredRemoteAddress, null, null,
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
preferredRemoteAddress, "Not a valid user-agent string", null,
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
preferredRemoteAddress, null, "ja,en;q=0.4",
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
preferredRemoteAddress, null, "This is not a valid language preference list",
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
);
}
@Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
remoteChannel.close().await();
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
}
}

View File

@@ -1,62 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
class HAProxyMessageHandlerTest {
private EmbeddedChannel embeddedChannel;
@BeforeEach
void setUp() {
embeddedChannel = new EmbeddedChannel(new HAProxyMessageHandler());
}
@Test
void handleHAProxyMessage() throws InterruptedException {
final HAProxyMessage haProxyMessage = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"10.0.0.1", "10.0.0.2", 12345, 443);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(haProxyMessage);
embeddedChannel.flushInbound();
writeFuture.await();
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(0, haProxyMessage.refCnt());
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
}
@Test
void handleNonHAProxyMessage() throws InterruptedException {
final byte[] bytes = new byte[32];
ThreadLocalRandom.current().nextBytes(bytes);
final ByteBuf message = Unpooled.wrappedBuffer(bytes);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message);
embeddedChannel.flushInbound();
writeFuture.await();
assertEquals(1, embeddedChannel.inboundMessages().size());
assertEquals(message, embeddedChannel.inboundMessages().poll());
assertEquals(1, message.refCnt());
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
}
}

View File

@@ -1,116 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.protobuf.ByteString;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import java.util.Optional;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.junit.jupiter.api.Test;
class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
@Override
protected CipherStatePair doHandshake() throws Exception {
return doHandshake(baseHandshakeInit().build().toByteArray());
}
private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
final HandshakeState clientHandshakeState =
new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
clientHandshakeState.start();
// Send initiator handshake message
// 32 byte key, request payload, 16 byte AEAD tag
final int initiateHandshakeMessageLength = 32 + requestPayload.length + 16;
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeMessageLength];
assertEquals(
initiateHandshakeMessageLength,
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
final NoiseHandshakeInit message = new NoiseHandshakeInit(
REMOTE_ADDRESS,
HandshakePattern.NK,
Unpooled.wrappedBuffer(initiateHandshakeMessage));
assertTrue(embeddedChannel.writeOneInbound(message).await().isSuccess());
assertEquals(0, message.refCnt());
embeddedChannel.runPendingTasks();
// Read responder handshake message
assertFalse(embeddedChannel.outboundMessages().isEmpty());
final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
assertNotNull(responderHandshakeFrame);
final byte[] responderHandshakeBytes = ByteBufUtil.getBytes(responderHandshakeFrame);
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponse = NoiseTunnelProtos.HandshakeResponse.newBuilder()
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
.build();
// ephemeral key, payload, AEAD tag
assertEquals(32 + expectedHandshakeResponse.getSerializedSize() + 16, responderHandshakeBytes.length);
final byte[] handshakeResponsePlaintext = new byte[expectedHandshakeResponse.getSerializedSize()];
assertEquals(expectedHandshakeResponse.getSerializedSize(),
clientHandshakeState.readMessage(
responderHandshakeBytes, 0, responderHandshakeBytes.length,
handshakeResponsePlaintext, 0));
assertEquals(expectedHandshakeResponse, NoiseTunnelProtos.HandshakeResponse.parseFrom(handshakeResponsePlaintext));
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
return clientHandshakeState.split();
}
@Test
void handleCompleteHandshakeWithRequest() throws Exception {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
final byte[] handshakePlaintext = baseHandshakeInit()
.setFastOpenRequest(ByteString.copyFromUtf8("ping")).build()
.toByteArray();
final CipherStatePair cipherStatePair = doHandshake(handshakePlaintext);
final byte[] response = readNextPlaintext(cipherStatePair);
assertArrayEquals(response, "pong".getBytes());
assertEquals(
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
getNoiseHandshakeCompleteEvent());
}
@Test
void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake());
assertNull(readNextPlaintext(cipherStatePair));
assertEquals(
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
getNoiseHandshakeCompleteEvent());
}
}

View File

@@ -1,338 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.internal.EmptyArrays;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadLocalRandom;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
private final ECKeyPair clientKeyPair = ECKeyPair.generate();
@Override
protected CipherStatePair doHandshake() throws Throwable {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = randomDeviceId();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
return doHandshake(identityPayload(accountIdentifier, deviceId));
}
@Test
void handleCompleteHandshakeNoInitialRequest() throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = randomDeviceId();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
assertEquals(
new NoiseIdentityDeterminedEvent(
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
getNoiseHandshakeCompleteEvent());
}
@Test
void handleCompleteHandshakeWithInitialRequest() throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = randomDeviceId();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
final byte[] handshakeInit = identifiedHandshakeInit(accountIdentifier, deviceId)
.setFastOpenRequest(ByteString.copyFromUtf8("ping"))
.build()
.toByteArray();
final byte[] response = readNextPlaintext(doHandshake(handshakeInit));
assertEquals(4, response.length);
assertEquals("pong", new String(response));
assertEquals(
new NoiseIdentityDeterminedEvent(
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
getNoiseHandshakeCompleteEvent());
}
@Test
void handleCompleteHandshakeMissingIdentityInformation() {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
verifyNoInteractions(clientPublicKeysManager);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeMalformedIdentityInformation() {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
// no deviceId byte
byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
assertThrows(NoiseHandshakeException.class, () -> doHandshake(malformedIdentityPayload));
verifyNoInteractions(clientPublicKeysManager);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakeUnrecognizedDevice() throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = randomDeviceId();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
doHandshake(
identityPayload(accountIdentifier, deviceId),
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
"Noise stream handler should not be added to pipeline after failed handshake");
}
@Test
void handleCompleteHandshakePublicKeyMismatch() throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = randomDeviceId();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey())));
doHandshake(
identityPayload(accountIdentifier, deviceId),
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
assertNull(getNoiseHandshakeCompleteEvent());
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
"Handshake handler should not remove self from pipeline after failed handshake");
}
@Test
void handleInvalidExtraWrites()
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = randomDeviceId();
final HandshakeState clientHandshakeState = clientHandshakeState();
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
final NoiseHandshakeInit handshakeInit = new NoiseHandshakeInit(
REMOTE_ADDRESS,
HandshakePattern.IK,
Unpooled.wrappedBuffer(
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
assertTrue(embeddedChannel.writeOneInbound(handshakeInit).await().isSuccess());
// While waiting for the public key, send another message
final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
assertInstanceOf(IllegalArgumentException.class, f.exceptionNow());
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
embeddedChannel.runPendingTasks();
}
@Test
public void handleOversizeHandshakeMessage() {
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
ByteBuffer.wrap(big)
.put(UUIDUtil.toBytes(UUID.randomUUID()))
.put((byte) 0x01);
assertThrows(NoiseHandshakeException.class, () -> doHandshake(big));
}
@Test
public void handleKeyLookupError() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = randomDeviceId();
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
.thenReturn(CompletableFuture.failedFuture(new IOException()));
assertThrows(IOException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
}
private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
final HandshakeState clientHandshakeState =
new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
clientHandshakeState.start();
return clientHandshakeState;
}
private byte[] initiatorHandshakeMessage(final HandshakeState clientHandshakeState, final byte[] payload)
throws ShortBufferException {
// Ephemeral key, encrypted static key, AEAD tag, encrypted payload, AEAD tag
final byte[] initiatorMessageBytes = new byte[32 + 32 + 16 + payload.length + 16];
int written = clientHandshakeState.writeMessage(initiatorMessageBytes, 0, payload, 0, payload.length);
assertEquals(written, initiatorMessageBytes.length);
return initiatorMessageBytes;
}
private byte[] readHandshakeResponse(final HandshakeState clientHandshakeState, final byte[] message)
throws ShortBufferException, BadPaddingException {
// 32 byte ephemeral server key, 16 byte AEAD tag for encrypted payload
final int expectedResponsePayloadLength = message.length - 32 - 16;
final byte[] responsePayload = new byte[expectedResponsePayloadLength];
final int responsePayloadLength = clientHandshakeState.readMessage(message, 0, message.length, responsePayload, 0);
assertEquals(expectedResponsePayloadLength, responsePayloadLength);
return responsePayload;
}
private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
return doHandshake(payload, NoiseTunnelProtos.HandshakeResponse.Code.OK);
}
private CipherStatePair doHandshake(final byte[] payload, final NoiseTunnelProtos.HandshakeResponse.Code expectedStatus) throws Throwable {
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
final HandshakeState clientHandshakeState = clientHandshakeState();
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
final NoiseHandshakeInit initMessage = new NoiseHandshakeInit(
REMOTE_ADDRESS,
HandshakePattern.IK,
Unpooled.wrappedBuffer(initiatorMessage));
final ChannelFuture await = embeddedChannel.writeOneInbound(initMessage).await();
assertEquals(0, initMessage.refCnt());
if (!await.isSuccess() && expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
throw await.cause();
}
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
// and issue a "handshake complete" event.
embeddedChannel.runPendingTasks();
// rethrow if running the task caused an error, and the caller isn't expecting an error
if (expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
embeddedChannel.checkException();
}
assertFalse(embeddedChannel.outboundMessages().isEmpty());
final ByteBuf handshakeResponseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
assertNotNull(handshakeResponseFrame);
final byte[] handshakeResponseCiphertextBytes = ByteBufUtil.getBytes(handshakeResponseFrame);
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponsePlaintext = NoiseTunnelProtos.HandshakeResponse.newBuilder()
.setCode(expectedStatus)
.build();
final byte[] actualHandshakeResponsePlaintext =
readHandshakeResponse(clientHandshakeState, handshakeResponseCiphertextBytes);
assertEquals(
expectedHandshakeResponsePlaintext,
NoiseTunnelProtos.HandshakeResponse.parseFrom(actualHandshakeResponsePlaintext));
final byte[] serverPublicKey = new byte[32];
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
return clientHandshakeState.split();
}
private NoiseTunnelProtos.HandshakeInit.Builder identifiedHandshakeInit(final UUID accountIdentifier, final byte deviceId) {
return baseHandshakeInit()
.setAci(UUIDUtil.toByteString(accountIdentifier))
.setDeviceId(deviceId);
}
private byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
return identifiedHandshakeInit(accountIdentifier, deviceId)
.build()
.toByteArray();
}
private static byte randomDeviceId() {
return (byte) ThreadLocalRandom.current().nextInt(1, Device.MAXIMUM_DEVICE_ID + 1);
}
}

View File

@@ -1,66 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
import com.southernstorm.noise.protocol.HandshakeState;
import io.netty.buffer.ByteBuf;
import java.nio.charset.StandardCharsets;
import javax.crypto.ShortBufferException;
import io.netty.buffer.ByteBufUtil;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientHandshakeHelper;
public class NoiseHandshakeHelperTest {
@ParameterizedTest
@EnumSource(HandshakePattern.class)
void testWithPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), "pong".getBytes(StandardCharsets.UTF_8));
}
@ParameterizedTest
@EnumSource(HandshakePattern.class)
void testWithRequestPayload(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), new byte[0]);
}
@ParameterizedTest
@EnumSource(HandshakePattern.class)
void testWithoutPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
doHandshake(pattern, new byte[0], new byte[0]);
}
void doHandshake(final HandshakePattern pattern, final byte[] requestPayload, final byte[] responsePayload) throws ShortBufferException, NoiseHandshakeException {
final ECKeyPair serverKeyPair = ECKeyPair.generate();
final ECKeyPair clientKeyPair = ECKeyPair.generate();
NoiseHandshakeHelper serverHelper = new NoiseHandshakeHelper(pattern, serverKeyPair);
NoiseClientHandshakeHelper clientHelper = switch (pattern) {
case IK -> NoiseClientHandshakeHelper.IK(serverKeyPair.getPublicKey(), clientKeyPair);
case NK -> NoiseClientHandshakeHelper.NK(serverKeyPair.getPublicKey());
};
final byte[] initiate = clientHelper.write(requestPayload);
final ByteBuf actualRequestPayload = serverHelper.read(initiate);
assertThat(ByteBufUtil.getBytes(actualRequestPayload)).isEqualTo(requestPayload);
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.WRITE_MESSAGE);
final byte[] respond = serverHelper.write(responsePayload);
byte[] actualResponsePayload = clientHelper.read(respond);
assertThat(actualResponsePayload).isEqualTo(responsePayload);
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.SPLIT);
assertThatNoException().isThrownBy(() -> serverHelper.getHandshakeState().split());
assertThatNoException().isThrownBy(() -> clientHelper.split());
}
}

View File

@@ -1,108 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import java.util.HexFormat;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
class ProxyProtocolDetectionHandlerTest {
private EmbeddedChannel embeddedChannel;
private static final byte[] PROXY_V2_MESSAGE_BYTES =
HexFormat.of().parseHex("0d0a0d0a000d0a515549540a2111000c0a0000010a000002303901bb");
@BeforeEach
void setUp() {
embeddedChannel = new EmbeddedChannel(new ProxyProtocolDetectionHandler());
}
@Test
void singlePacketProxyMessage() throws InterruptedException {
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES));
embeddedChannel.flushInbound();
writeFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
}
@Test
void multiPacketProxyMessage() throws InterruptedException {
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, 0,
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
PROXY_V2_MESSAGE_BYTES.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
embeddedChannel.flushInbound();
firstWriteFuture.await();
secondWriteFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
}
@Test
void singlePacketNonProxyMessage() throws InterruptedException {
final byte[] nonProxyProtocolMessage = new byte[32];
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(nonProxyProtocolMessage));
embeddedChannel.flushInbound();
writeFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
assertInstanceOf(ByteBuf.class, inboundMessage);
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
}
@Test
void multiPacketNonProxyMessage() throws InterruptedException {
final byte[] nonProxyProtocolMessage = new byte[32];
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(nonProxyProtocolMessage, 0,
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
Unpooled.wrappedBuffer(nonProxyProtocolMessage, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
nonProxyProtocolMessage.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
embeddedChannel.flushInbound();
firstWriteFuture.await();
secondWriteFuture.await();
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
assertEquals(1, embeddedChannel.inboundMessages().size());
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
assertInstanceOf(ByteBuf.class, inboundMessage);
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
}
}

View File

@@ -1,16 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ClientErrorHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(ClientErrorHandler.class);
@Override
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
log.error("Caught inbound error in client; closing connection", cause);
context.channel().close();
}
}

View File

@@ -1,53 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
public enum CloseReason {
OK,
SERVER_CLOSED,
NOISE_ERROR,
NOISE_HANDSHAKE_ERROR,
INTERNAL_SERVER_ERROR,
UNKNOWN
}
public enum CloseInitiator {
SERVER,
CLIENT
}
public static CloseFrameEvent fromWebsocketCloseFrame(
CloseWebSocketFrame closeWebSocketFrame,
CloseInitiator closeInitiator) {
final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
case 4002 -> CloseReason.NOISE_ERROR;
case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
case 1012 -> CloseReason.SERVER_CLOSED;
case 1000 -> CloseReason.OK;
default -> CloseReason.UNKNOWN;
};
return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
}
public static CloseFrameEvent fromNoiseDirectCloseFrame(
NoiseDirectProtos.CloseReason noiseDirectCloseReason,
CloseInitiator closeInitiator) {
final CloseReason code = switch (noiseDirectCloseReason.getCode()) {
case OK -> CloseReason.OK;
case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
};
return new CloseFrameEvent(code, closeInitiator, noiseDirectCloseReason.getMessage());
}
}

View File

@@ -1,120 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
/**
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
* gRPC server.
* <p>
* This handler waits until the first gRPC client message is ready and then establishes a connection with the remote
* gRPC server. It expects the provided remoteHandlerStack to emit a {@link ReadyForNoiseHandshakeEvent} when the remote
* connection is ready for its first inbound payload, and to emit a {@link NoiseClientHandshakeCompleteEvent} when the
* handshake is finished.
*/
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final List<ChannelHandler> remoteHandlerStack;
private final NoiseTunnelProtos.HandshakeInit handshakeInit;
private final SocketAddress remoteServerAddress;
// If provided, will be sent with the payload in the noise handshake
private final List<Object> pendingReads = new ArrayList<>();
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
EstablishRemoteConnectionHandler(
final List<ChannelHandler> remoteHandlerStack,
final SocketAddress remoteServerAddress,
final NoiseTunnelProtos.HandshakeInit handshakeInit) {
this.remoteHandlerStack = remoteHandlerStack;
this.handshakeInit = handshakeInit;
this.remoteServerAddress = remoteServerAddress;
}
@Override
public void handlerAdded(final ChannelHandlerContext localContext) {
new Bootstrap()
.channel(NioSocketChannel.class)
.group(localContext.channel().eventLoop())
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(final SocketChannel channel) throws Exception {
for (ChannelHandler handler : remoteHandlerStack) {
channel.pipeline().addLast(handler);
}
channel.pipeline()
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
throws Exception {
switch (event) {
case ReadyForNoiseHandshakeEvent ignored ->
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeInit.toByteArray()))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
case NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) -> {
remoteContext.pipeline()
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
// If there was a payload response on the handshake, write it back to our gRPC client
if (!handshakeResponse.getFastOpenResponse().isEmpty()) {
localContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeResponse
.getFastOpenResponse()
.asReadOnlyByteBuffer()));
}
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
pendingReads.forEach(localContext::fireChannelRead);
pendingReads.clear();
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
}
default -> {
}
}
super.userEventTriggered(remoteContext, event);
}
})
.addLast(new ClientErrorHandler());
}
})
.connect(remoteServerAddress)
.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
// Close the local connection if the remote channel closes and vice versa
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
} else {
localContext.close();
}
});
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
pendingReads.add(message);
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
pendingReads.forEach(ReferenceCountUtil::release);
pendingReads.clear();
}
}

View File

@@ -1,9 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.buffer.ByteBuf;
public record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}

View File

@@ -1,28 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import java.util.function.Supplier;
class HAProxyMessageSender extends ChannelInboundHandlerAdapter {
private final Supplier<HAProxyMessage> messageSupplier;
HAProxyMessageSender(final Supplier<HAProxyMessage> messageSupplier) {
this.messageSupplier = messageSupplier;
}
@Override
public void handlerAdded(final ChannelHandlerContext context) {
if (context.channel().isActive()) {
context.writeAndFlush(messageSupplier.get());
}
}
@Override
public void channelActive(final ChannelHandlerContext context) {
context.writeAndFlush(messageSupplier.get());
context.fireChannelActive();
}
}

View File

@@ -1,185 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.client;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.ReferenceCountUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
import java.util.stream.Stream;
/**
* The noise tunnel streams bytes out of a gRPC client through noise and to a remote server. The server supports a "fast
* open" optimization where the client can send a request along with the noise handshake. There's no direct way to
* extract the request boundaries from the gRPC client's byte-stream, so {@link Http2Buffering#handler()} provides an
* inbound pipeline handler that will parse the byte-stream back into HTTP/2 frames and buffer the first request.
* <p>
* Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a
* {@link FastOpenRequestBufferedEvent}
*/
public class Http2Buffering {
/**
* Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request
*/
public static ChannelInboundHandler handler() {
return new Http2PrefaceHandler();
}
private Http2Buffering() {
}
private static class Http2PrefaceHandler extends ChannelInboundHandlerAdapter {
// https://www.rfc-editor.org/rfc/rfc7540.html#section-3.5
private static final byte[] HTTP2_PREFACE =
HexFormat.of().parseHex("505249202a20485454502f322e300d0a0d0a534d0d0a0d0a");
private final ByteBuf read = Unpooled.buffer(HTTP2_PREFACE.length, HTTP2_PREFACE.length);
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
if (message instanceof ByteBuf bb) {
bb.readBytes(read);
if (read.readableBytes() < HTTP2_PREFACE.length) {
// Copied the message into the read buffer, but haven't yet got a full HTTP2 preface. Wait for more.
return;
}
if (!Arrays.equals(read.array(), HTTP2_PREFACE)) {
throw new IllegalStateException("HTTP/2 stream must start with HTTP/2 preface");
}
context.pipeline().replace(this, "http2frame1", new Http2LengthFieldFrameDecoder());
context.pipeline().addAfter("http2frame1", "http2frame2", new Http2FrameDecoder());
context.pipeline().addAfter("http2frame2", "http2frame3", new Http2FirstRequestHandler());
context.fireChannelRead(bb);
} else {
throw new IllegalStateException("Unexpected message: " + message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
ReferenceCountUtil.release(read);
}
}
private record Http2Frame(ByteBuf bytes, FrameType type, boolean endStream) {
private static final byte FLAG_END_STREAM = 0x01;
enum FrameType {
SETTINGS,
HEADERS,
DATA,
WINDOW_UPDATE,
OTHER;
static FrameType fromSerializedType(final byte type) {
return switch (type) {
case 0x00 -> Http2Frame.FrameType.DATA;
case 0x01 -> Http2Frame.FrameType.HEADERS;
case 0x04 -> Http2Frame.FrameType.SETTINGS;
case 0x08 -> Http2Frame.FrameType.WINDOW_UPDATE;
default -> Http2Frame.FrameType.OTHER;
};
}
}
}
/**
* Emit ByteBuf of entire HTTP/2 frame
*/
private static class Http2LengthFieldFrameDecoder extends LengthFieldBasedFrameDecoder {
public Http2LengthFieldFrameDecoder() {
// Frames are 3 bytes of length, 6 bytes of other header, and then length bytes of payload
super(16 * 1024 * 1024, 0, 3, 6, 0);
}
}
/**
* Parse the serialized Http/2 frames into {@link Http2Frame} objects
*/
private static class Http2FrameDecoder extends ByteToMessageDecoder {
@Override
protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List<Object> out) throws Exception {
// https://www.rfc-editor.org/rfc/rfc7540.html#section-4.1
final Http2Frame.FrameType frameType = Http2Frame.FrameType.fromSerializedType(in.getByte(in.readerIndex() + 3));
final boolean endStream = endStream(frameType, in.getByte(in.readerIndex() + 4));
out.add(new Http2Frame(in.readBytes(in.readableBytes()), frameType, endStream));
}
boolean endStream(Http2Frame.FrameType frameType, byte flags) {
// A gRPC request are packed into HTTP/2 frames like:
// HEADERS frame | DATA frame 1 (endStream=0) | ... | DATA frame N (endstream=1)
//
// Our goal is to get an entire request buffered, so as soon as we see a DATA frame with the end stream flag set
// we have a whole request. Note that we could have pieces of multiple requests, but the only thing we care about
// is having at least one complete request. In total, we can expect something like:
// HTTP-preface | SETTINGS frame | Frames we don't care about ... | DATA (endstream=1)
//
// The connection isn't 'established' until the server has responded with their own SETTINGS frame with the ack
// bit set, but HTTP/2 allows the client to send frames before getting the ACK.
if (frameType == Http2Frame.FrameType.DATA) {
return (flags & Http2Frame.FLAG_END_STREAM) == Http2Frame.FLAG_END_STREAM;
}
// In theory, at least. Unfortunately, the java gRPC client always waits for the HTTP/2 handshake to complete
// (which requires the server sending back the ack) before it actually sends any requests. So if we waited for a
// DATA frame, it would never come. The gRPC-java implementation always at least sends a WINDOW_UPDATE, so we
// might as well pack that in.
return frameType == Http2Frame.FrameType.WINDOW_UPDATE;
}
}
/**
* Collect HTTP/2 frames until we get an entire "request" to send
*/
private static class Http2FirstRequestHandler extends ChannelInboundHandlerAdapter {
final List<Http2Frame> pendingFrames = new ArrayList<>();
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
if (message instanceof Http2Frame http2Frame) {
if (pendingFrames.isEmpty() && http2Frame.type != Http2Frame.FrameType.SETTINGS) {
throw new IllegalStateException(
"HTTP/2 stream must start with HTTP/2 SETTINGS frame, got " + http2Frame.type);
}
pendingFrames.add(http2Frame);
if (http2Frame.endStream) {
// We have a whole "request", emit the first request event and remove the http2 buffering handlers
final ByteBuf request = Unpooled.wrappedBuffer(Stream.concat(
Stream.of(Unpooled.wrappedBuffer(Http2PrefaceHandler.HTTP2_PREFACE)),
pendingFrames.stream().map(Http2Frame::bytes))
.toArray(ByteBuf[]::new));
pendingFrames.clear();
context.pipeline().remove(Http2LengthFieldFrameDecoder.class);
context.pipeline().remove(Http2FrameDecoder.class);
context.pipeline().remove(this);
context.fireUserEventTriggered(new FastOpenRequestBufferedEvent(request));
}
} else {
throw new IllegalStateException("Unexpected message: " + message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
pendingFrames.forEach(frame -> ReferenceCountUtil.release(frame.bytes()));
pendingFrames.clear();
}
}
}

View File

@@ -1,17 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.client;
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
import java.util.Optional;
/**
* A netty user event that indicates that the noise handshake finished successfully.
*
* @param fastResponse A response if the client included a request to send in the initiate handshake message payload and
* the server included a payload in the handshake response.
*/
public record NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) {}

View File

@@ -1,62 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import com.google.protobuf.InvalidProtocolBufferException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
private final NoiseClientHandshakeHelper handshakeHelper;
public NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper) {
this.handshakeHelper = handshakeHelper;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof ByteBuf plaintextHandshakePayload) {
final byte[] payloadBytes = ByteBufUtil.getBytes(plaintextHandshakePayload,
plaintextHandshakePayload.readerIndex(), plaintextHandshakePayload.readableBytes(),
false);
final byte[] handshakeMessage = handshakeHelper.write(payloadBytes);
ctx.write(Unpooled.wrappedBuffer(handshakeMessage), promise);
} else {
ctx.write(msg, promise);
}
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message)
throws NoiseHandshakeException {
if (message instanceof ByteBuf frame) {
try {
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame));
final NoiseTunnelProtos.HandshakeResponse handshakeResponse =
NoiseTunnelProtos.HandshakeResponse.parseFrom(payload);
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(handshakeResponse));
} catch (InvalidProtocolBufferException e) {
throw new NoiseHandshakeException("Failed to parse handshake response");
} finally {
frame.release();
}
} else {
context.fireChannelRead(message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
handshakeHelper.destroy();
}
}

View File

@@ -1,95 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.client;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.HandshakeState;
import java.security.NoSuchAlgorithmException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
public class NoiseClientHandshakeHelper {
private final HandshakePattern handshakePattern;
private final HandshakeState handshakeState;
private NoiseClientHandshakeHelper(HandshakePattern handshakePattern, HandshakeState handshakeState) {
this.handshakePattern = handshakePattern;
this.handshakeState = handshakeState;
}
public static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
try {
final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0);
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
state.start();
return new NoiseClientHandshakeHelper(HandshakePattern.IK, state);
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e);
}
}
public static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
try {
final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
state.start();
return new NoiseClientHandshakeHelper(HandshakePattern.NK, state);
} catch (NoSuchAlgorithmException e) {
throw new IllegalArgumentException(e);
}
}
public byte[] write(final byte[] requestPayload) throws ShortBufferException {
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16];
handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length);
return initiateHandshakeMessage;
}
private int initiateHandshakeKeysLength() {
return switch (handshakePattern) {
// 32-byte ephemeral key, 32-byte encrypted static key, 16-byte AEAD tag
case IK -> 32 + 32 + 16;
// 32-byte ephemeral key
case NK -> 32;
};
}
public byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
// Don't process additional messages if the handshake failed and we're just waiting to close
if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) {
throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction());
}
final int payloadLength = responderHandshakeMessage.length - 16 - 32;
final byte[] responsePayload = new byte[payloadLength];
final int payloadBytesRead;
try {
payloadBytesRead = handshakeState
.readMessage(responderHandshakeMessage, 0, responderHandshakeMessage.length, responsePayload, 0);
if (payloadBytesRead != responsePayload.length) {
throw new IllegalStateException(
"Unexpected payload length, required " + payloadLength + " got " + payloadBytesRead);
}
return responsePayload;
} catch (ShortBufferException e) {
throw new IllegalStateException("Failed to deserialize payload of known length" + e.getMessage());
} catch (BadPaddingException e) {
throw new NoiseHandshakeException(e.getMessage());
}
}
public CipherStatePair split() {
return this.handshakeState.split();
}
public void destroy() {
this.handshakeState.destroy();
}
}

View File

@@ -1,89 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import com.southernstorm.noise.protocol.CipherState;
import com.southernstorm.noise.protocol.CipherStatePair;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
/**
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
*/
public class NoiseClientTransportHandler extends ChannelDuplexHandler {
private final CipherStatePair cipherStatePair;
private static final Logger log = LoggerFactory.getLogger(NoiseClientTransportHandler.class);
NoiseClientTransportHandler(CipherStatePair cipherStatePair) {
this.cipherStatePair = cipherStatePair;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof ByteBuf frame) {
final CipherState cipherState = cipherStatePair.getReceiver();
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer.
final byte[] noiseBuffer = ByteBufUtil.getBytes(frame);
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
} else {
// Anything except binary frames should have been filtered out of the pipeline by now; treat this as an
// error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} finally {
ReferenceCountUtil.release(message);
}
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
if (message instanceof ByteBuf plaintext) {
try {
final CipherState cipherState = cipherStatePair.getSender();
final int plaintextLength = plaintext.readableBytes();
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
context.write(Unpooled.wrappedBuffer(noiseBuffer), promise);
} finally {
ReferenceCountUtil.release(plaintext);
}
} else {
if (!(message instanceof CloseWebSocketFrame || message instanceof NoiseDirectFrame)) {
// Clients only write ByteBufs or a close frame on errors, so any other message is unexpected
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
cipherStatePair.destroy();
}
}

View File

@@ -1,408 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
import com.google.protobuf.ByteString;
import com.southernstorm.noise.protocol.Noise;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.MessageToMessageCodec;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.ReferenceCountUtil;
import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.net.FramingType;
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
public class NoiseTunnelClient implements AutoCloseable {
private final CompletableFuture<CloseFrameEvent> closeEventFuture;
private final CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture;
private final CompletableFuture<Void> userCloseFuture;
private final ServerBootstrap serverBootstrap;
private Channel serverChannel;
public static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
public static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
public static class Builder {
final SocketAddress remoteServerAddress;
NioEventLoopGroup eventLoopGroup;
ECPublicKey serverPublicKey;
FramingType framingType = FramingType.WEBSOCKET;
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
HttpHeaders headers = new DefaultHttpHeaders();
NoiseTunnelProtos.HandshakeInit.Builder handshakeInit = NoiseTunnelProtos.HandshakeInit.newBuilder();
boolean authenticated = false;
ECKeyPair ecKeyPair = null;
boolean useTls;
X509Certificate trustedServerCertificate = null;
Supplier<HAProxyMessage> proxyMessageSupplier = null;
public Builder(
final SocketAddress remoteServerAddress,
final NioEventLoopGroup eventLoopGroup,
final ECPublicKey serverPublicKey) {
this.remoteServerAddress = remoteServerAddress;
this.eventLoopGroup = eventLoopGroup;
this.serverPublicKey = serverPublicKey;
}
public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
this.authenticated = true;
handshakeInit.setAci(UUIDUtil.toByteString(accountIdentifier));
handshakeInit.setDeviceId(deviceId);
this.ecKeyPair = ecKeyPair;
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
return this;
}
public Builder setWebsocketUri(final URI websocketUri) {
this.websocketUri = websocketUri;
return this;
}
public Builder setUseTls(X509Certificate trustedServerCertificate) {
this.useTls = true;
this.trustedServerCertificate = trustedServerCertificate;
return this;
}
public Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
this.proxyMessageSupplier = proxyMessageSupplier;
return this;
}
public Builder setUserAgent(final String userAgent) {
handshakeInit.setUserAgent(userAgent);
return this;
}
public Builder setAcceptLanguage(final String acceptLanguage) {
handshakeInit.setAcceptLanguage(acceptLanguage);
return this;
}
public Builder setHeaders(final HttpHeaders headers) {
this.headers = headers;
return this;
}
public Builder setServerPublicKey(ECPublicKey serverPublicKey) {
this.serverPublicKey = serverPublicKey;
return this;
}
public Builder setFramingType(FramingType framingType) {
this.framingType = framingType;
return this;
}
public NoiseTunnelClient build() {
final List<ChannelHandler> handlers = new ArrayList<>();
if (proxyMessageSupplier != null) {
handlers.addAll(List.of(HAProxyMessageEncoder.INSTANCE, new HAProxyMessageSender(proxyMessageSupplier)));
}
if (useTls) {
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (trustedServerCertificate != null) {
sslContextBuilder.trustManager(trustedServerCertificate);
}
try {
handlers.add(sslContextBuilder.build().newHandler(ByteBufAllocator.DEFAULT));
} catch (SSLException e) {
throw new IllegalArgumentException(e);
}
}
// handles the wrapping and unrwrapping the framing layer (websockets or noisedirect)
handlers.addAll(switch (framingType) {
case WEBSOCKET -> websocketHandlerStack(websocketUri, headers);
case NOISE_DIRECT -> noiseDirectHandlerStack(authenticated);
});
final NoiseClientHandshakeHelper helper = authenticated
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
: NoiseClientHandshakeHelper.NK(serverPublicKey);
handlers.add(new NoiseClientHandshakeHandler(helper));
// When the noise handshake completes we'll save the response from the server so client users can inspect it
final UserEventFuture<NoiseClientHandshakeCompleteEvent> handshakeEventHandler =
new UserEventFuture<>(NoiseClientHandshakeCompleteEvent.class);
handlers.add(handshakeEventHandler);
// Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off
// information about why the connection was closed.
final UserEventFuture<CloseFrameEvent> closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class);
handlers.add(closeEventHandler);
// When the user closes the client, write a normal closure close frame
final CompletableFuture<Void> userCloseFuture = new CompletableFuture<>();
handlers.add(new ChannelInboundHandlerAdapter() {
@Override
public void handlerAdded(final ChannelHandlerContext ctx) {
userCloseFuture.thenRunAsync(() -> ctx.pipeline().writeAndFlush(switch (framingType) {
case WEBSOCKET -> new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE);
case NOISE_DIRECT -> new NoiseDirectFrame(
NoiseDirectFrame.FrameType.CLOSE,
Unpooled.wrappedBuffer(NoiseDirectProtos.CloseReason
.newBuilder()
.setCode(NoiseDirectProtos.CloseReason.Code.OK)
.build()
.toByteArray()));
})
.addListener(ChannelFutureListener.CLOSE),
ctx.executor());
}
});
final NoiseTunnelClient client =
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, handshakeEventHandler.future, userCloseFuture, fastOpenRequest -> new EstablishRemoteConnectionHandler(
handlers,
remoteServerAddress,
handshakeInit.setFastOpenRequest(ByteString.copyFrom(fastOpenRequest)).build()));
client.start();
return client;
}
}
private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup,
CompletableFuture<CloseFrameEvent> closeEventFuture,
CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture,
CompletableFuture<Void> userCloseFuture,
Function<byte[], EstablishRemoteConnectionHandler> handler) {
this.userCloseFuture = userCloseFuture;
this.closeEventFuture = closeEventFuture;
this.handshakeEventFuture = handshakeEventFuture;
this.serverBootstrap = new ServerBootstrap()
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
.channel(LocalServerChannel.class)
.group(eventLoopGroup)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline()
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
// in the handshake.
.addLast(Http2Buffering.handler())
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
// connect to the remote service
.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
if (evt instanceof FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest)) {
byte[] fastOpenRequestBytes = ByteBufUtil.getBytes(fastOpenRequest);
fastOpenRequest.release();
ctx.pipeline().addLast(handler.apply(fastOpenRequestBytes));
}
super.userEventTriggered(ctx, evt);
}
})
.addLast(new ClientErrorHandler());
}
});
}
private static class UserEventFuture<T> extends ChannelInboundHandlerAdapter {
private final CompletableFuture<T> future = new CompletableFuture<>();
private final Class<T> cls;
UserEventFuture(Class<T> cls) {
this.cls = cls;
}
@Override
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) {
if (cls.isInstance(evt)) {
future.complete((T) evt);
}
ctx.fireUserEventTriggered(evt);
}
}
public LocalAddress getLocalAddress() {
return (LocalAddress) serverChannel.localAddress();
}
private NoiseTunnelClient start() {
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
return this;
}
@Override
public void close() throws InterruptedException {
userCloseFuture.complete(null);
serverChannel.close().await();
}
/**
* @return A future that completes when a close frame is observed
*/
public CompletableFuture<CloseFrameEvent> closeFrameFuture() {
return closeEventFuture;
}
/**
* @return A future that completes when the noise handshake finishes
*/
public CompletableFuture<NoiseClientHandshakeCompleteEvent> getHandshakeEventFuture() {
return handshakeEventFuture;
}
private static List<ChannelHandler> noiseDirectHandlerStack(boolean authenticated) {
return List.of(
new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2),
new NoiseDirectFrameCodec(),
new ChannelDuplexHandler() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
ctx.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
ctx.fireChannelActive();
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
try {
final NoiseDirectProtos.CloseReason closeReason =
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
ctx.fireUserEventTriggered(
CloseFrameEvent.fromNoiseDirectCloseFrame(closeReason, CloseFrameEvent.CloseInitiator.SERVER));
} finally {
ReferenceCountUtil.release(msg);
}
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
final NoiseDirectProtos.CloseReason errorPayload =
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
ctx.fireUserEventTriggered(
CloseFrameEvent.fromNoiseDirectCloseFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
}
ctx.write(msg, promise);
}
},
new MessageToMessageCodec<NoiseDirectFrame, ByteBuf>() {
boolean noiseHandshakeFinished = false;
@Override
protected void encode(final ChannelHandlerContext ctx, final ByteBuf msg, final List<Object> out) {
final NoiseDirectFrame.FrameType frameType = noiseHandshakeFinished
? NoiseDirectFrame.FrameType.DATA
: (authenticated ? NoiseDirectFrame.FrameType.IK_HANDSHAKE : NoiseDirectFrame.FrameType.NK_HANDSHAKE);
noiseHandshakeFinished = true;
out.add(new NoiseDirectFrame(frameType, msg.retain()));
}
@Override
protected void decode(final ChannelHandlerContext ctx, final NoiseDirectFrame msg,
final List<Object> out) {
out.add(msg.content().retain());
}
});
}
private static List<ChannelHandler> websocketHandlerStack(final URI websocketUri, final HttpHeaders headers) {
return List.of(
new HttpClientCodec(),
new HttpObjectAggregator(Noise.MAX_PACKET_LEN),
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
// want to react to them on our own, we need to catch them before they hit that handler.
new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
context.fireUserEventTriggered(
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.SERVER));
}
super.channelRead(context, message);
}
},
new WebSocketClientProtocolHandler(websocketUri,
WebSocketVersion.V13,
null,
false,
headers,
Noise.MAX_PACKET_LEN,
10_000),
new ChannelOutboundHandlerAdapter() {
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
context.fireUserEventTriggered(
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.CLIENT));
}
super.write(context, message, promise);
}
},
new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
context.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
}
}
context.fireUserEventTriggered(event);
}
},
new WebsocketPayloadCodec());
}
}

View File

@@ -1,4 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.client;
public record ReadyForNoiseHandshakeEvent() {
}

View File

@@ -1,50 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.net.FramingType;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import java.util.concurrent.Executor;
class DirectNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
private NoiseDirectTunnelServer noiseDirectTunnelServer;
@Override
protected void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception {
noiseDirectTunnelServer = new NoiseDirectTunnelServer(0,
eventLoopGroup,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress);
noiseDirectTunnelServer.start();
}
@Override
protected void stop() throws InterruptedException {
noiseDirectTunnelServer.stop();
}
@Override
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
return new NoiseTunnelClient
.Builder(noiseDirectTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
.setFramingType(FramingType.NOISE_DIRECT);
}
}

View File

@@ -1,72 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {
private EmbeddedChannel embeddedChannel;
@BeforeEach
void setUp() {
embeddedChannel = new EmbeddedChannel(new RejectUnsupportedMessagesHandler());
}
@ParameterizedTest
@MethodSource
void allowWebSocketFrame(final WebSocketFrame frame) {
embeddedChannel.writeOneInbound(frame);
try {
assertEquals(frame, embeddedChannel.inboundMessages().poll());
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(1, frame.refCnt());
} finally {
frame.release();
}
}
private static List<WebSocketFrame> allowWebSocketFrame() {
return List.of(
new BinaryWebSocketFrame(),
new CloseWebSocketFrame(),
new ContinuationWebSocketFrame(),
new PingWebSocketFrame(),
new PongWebSocketFrame());
}
@Test
void rejectTextFrame() {
final TextWebSocketFrame textFrame = new TextWebSocketFrame();
embeddedChannel.writeOneInbound(textFrame);
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(0, textFrame.refCnt());
}
@Test
void rejectNonWebSocketFrame() {
final ByteBuf bytes = Unpooled.buffer(0);
embeddedChannel.writeOneInbound(bytes);
assertTrue(embeddedChannel.inboundMessages().isEmpty());
assertEquals(0, bytes.refCnt());
}
}

View File

@@ -1,239 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.io.ByteArrayInputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeoutException;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer;
private X509Certificate serverTlsCertificate;
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
// They were generated with:
//
// ```shell
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
// ```
private static final String SERVER_CERTIFICATE = """
-----BEGIN CERTIFICATE-----
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
iOr9sHiO8Rn2u0xRKgU5Ig==
-----END CERTIFICATE-----
""";
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
private static final String SERVER_PRIVATE_KEY = """
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
""";
@Override
protected void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception {
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
final PrivateKey serverTlsPrivateKey;
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
serverTlsPrivateKey =
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
new X509Certificate[]{serverTlsCertificate},
serverTlsPrivateKey,
eventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
recognizedProxySecret);
tlsNoiseWebSocketTunnelServer.start();
}
@Override
protected void stop() throws InterruptedException {
tlsNoiseWebSocketTunnelServer.stop();
}
@Override
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup,
final ECPublicKey serverPublicKey) {
return new NoiseTunnelClient
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
.setUseTls(serverTlsCertificate);
}
@Test
void getRequestAttributes() throws InterruptedException {
final String remoteAddress = "4.5.6.7";
final String acceptLanguage = "en";
final String userAgent = "Signal-Desktop/1.2.3 Linux";
final HttpHeaders headers = new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add("X-Forwarded-For", remoteAddress);
try (final NoiseTunnelClient client = anonymous()
.setHeaders(headers)
.setUserAgent(userAgent)
.setAcceptLanguage(acceptLanguage)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals(userAgent, response.getUserAgent());
} finally {
channel.shutdown();
}
}
}
@Test
void connectAuthenticatedToAnonymousService() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = authenticated()
.setWebsocketUri(NoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void connectAnonymousToAuthenticatedService() throws InterruptedException, ExecutionException, TimeoutException {
try (final NoiseTunnelClient client = anonymous()
.setWebsocketUri(NoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
.build()) {
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
}
}
@Test
void rejectIllegalRequests() throws Exception {
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
final TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(keyStore);
final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
final URI authenticatedUri =
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated",
null, null);
final URI incorrectUri =
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect",
null, null);
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
.uri(authenticatedUri)
.PUT(HttpRequest.BodyPublishers.ofString("test"))
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"Non-GET requests should not be allowed");
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(authenticatedUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests without upgrade headers should not be allowed");
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
.GET()
.uri(incorrectUri)
.build(),
HttpResponse.BodyHandlers.ofString()).statusCode(),
"GET requests to unrecognized URIs should not be allowed");
}
}
}

View File

@@ -1,50 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import java.util.concurrent.Executor;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer;
@Override
protected void start(
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair serverKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws Exception {
plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
null,
null,
eventLoopGroup,
delegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
recognizedProxySecret);
plaintextNoiseWebSocketTunnelServer.start();
}
@Override
protected void stop() throws InterruptedException {
plaintextNoiseWebSocketTunnelServer.stop();
}
@Override
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
return new NoiseTunnelClient
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey);
}
}

View File

@@ -1,115 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {
private EmbeddedChannel embeddedChannel;
private static final String AUTHENTICATED_PATH = "/authenticated";
private static final String ANONYMOUS_PATH = "/anonymous";
private static final String HEALTH_CHECK_PATH = "/health-check";
@BeforeEach
void setUp() {
embeddedChannel =
new EmbeddedChannel(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_PATH, ANONYMOUS_PATH, HEALTH_CHECK_PATH));
}
@ParameterizedTest
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
void handleValidRequest(final String path) {
final FullHttpRequest request = buildRequest(HttpMethod.GET, path,
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
try {
embeddedChannel.writeOneInbound(request);
assertEquals(1, request.refCnt());
assertEquals(1, embeddedChannel.inboundMessages().size());
assertEquals(request, embeddedChannel.inboundMessages().poll());
} finally {
request.release();
}
}
@Test
void handleHealthCheckRequest() {
final FullHttpRequest request = buildRequest(HttpMethod.GET, HEALTH_CHECK_PATH, new DefaultHttpHeaders());
embeddedChannel.writeOneInbound(request);
assertEquals(0, request.refCnt());
assertHttpResponse(HttpResponseStatus.NO_CONTENT);
}
@ParameterizedTest
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
void handleUpgradeRequired(final String path) {
final FullHttpRequest request = buildRequest(HttpMethod.GET, path, new DefaultHttpHeaders());
embeddedChannel.writeOneInbound(request);
assertEquals(0, request.refCnt());
assertHttpResponse(HttpResponseStatus.UPGRADE_REQUIRED);
}
@Test
void handleBadPath() {
final FullHttpRequest request = buildRequest(HttpMethod.GET, "/incorrect",
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
embeddedChannel.writeOneInbound(request);
assertEquals(0, request.refCnt());
assertHttpResponse(HttpResponseStatus.NOT_FOUND);
}
@ParameterizedTest
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
void handleMethodNotAllowed(final String path) {
final FullHttpRequest request = buildRequest(HttpMethod.DELETE, path,
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
embeddedChannel.writeOneInbound(request);
assertEquals(0, request.refCnt());
assertHttpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED);
}
private void assertHttpResponse(final HttpResponseStatus expectedStatus) {
assertEquals(1, embeddedChannel.outboundMessages().size());
final FullHttpResponse response = assertInstanceOf(FullHttpResponse.class, embeddedChannel.outboundMessages().poll());
assertEquals(expectedStatus, response.status());
}
private FullHttpRequest buildRequest(final HttpMethod method, final String path, final HttpHeaders headers) {
return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1,
method,
path,
Unpooled.buffer(0),
headers,
new DefaultHttpHeaders());
}
}

View File

@@ -1,233 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import com.google.common.net.InetAddresses;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
private UserEventRecordingHandler userEventRecordingHandler;
private MutableRemoteAddressEmbeddedChannel embeddedChannel;
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
private final List<Object> receivedEvents = new ArrayList<>();
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
receivedEvents.add(event);
}
public List<Object> getReceivedEvents() {
return receivedEvents;
}
}
private static class MutableRemoteAddressEmbeddedChannel extends EmbeddedChannel {
private SocketAddress remoteAddress;
public MutableRemoteAddressEmbeddedChannel(final ChannelHandler... handlers) {
super(handlers);
}
@Override
protected SocketAddress remoteAddress0() {
return isActive() ? remoteAddress : null;
}
public void setRemoteAddress(final SocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
}
@BeforeEach
void setUp() {
userEventRecordingHandler = new UserEventRecordingHandler();
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
new WebsocketHandshakeCompleteHandler(RECOGNIZED_PROXY_SECRET),
userEventRecordingHandler);
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeComplete(final String uri, final HandshakePattern pattern) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
final byte[] payload = TestRandomUtil.nextBytes(100);
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
assertNotNull(init);
assertEquals(init.getHandshakePattern(), pattern);
}
private static List<Arguments> handleWebSocketHandshakeComplete() {
return List.of(
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, HandshakePattern.IK),
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, HandshakePattern.NK));
}
@Test
void handleWebSocketHandshakeCompleteUnexpectedPath() {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
}
@Test
void handleUnrecognizedEvent() {
final Object unrecognizedEvent = new Object();
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
}
@ParameterizedTest
@MethodSource
void getRemoteAddress(final HttpHeaders headers, final SocketAddress remoteAddress, @Nullable InetAddress expectedRemoteAddress) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(
NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, headers, null);
embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
final byte[] payload = TestRandomUtil.nextBytes(100);
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
assertEquals(
expectedRemoteAddress,
Optional.ofNullable(init)
.map(NoiseHandshakeInit::getRemoteAddress)
.orElse(null));
if (expectedRemoteAddress == null) {
assertThrows(IllegalStateException.class, embeddedChannel::checkException);
} else {
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
}
}
private static List<Arguments> getRemoteAddress() {
final InetSocketAddress remoteAddress = new InetSocketAddress("5.6.7.8", 0);
final InetAddress clientAddress = InetAddresses.forString("1.2.3.4");
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of(
argumentSet("Recognized proxy, single forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
clientAddress),
argumentSet("Recognized proxy, multiple forwarded-for addresses",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress,
proxyAddress),
argumentSet("No recognized proxy header, single forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("No recognized proxy header, no forwarded-for address",
new DefaultHttpHeaders(),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("Incorrect proxy header, single forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("Recognized proxy, no forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress,
remoteAddress.getAddress()),
argumentSet("Recognized proxy, bogus forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress,
null),
argumentSet("No forwarded-for address, non-InetSocketAddress remote address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"),
null)
);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ParameterizedTest
@MethodSource("argumentsForGetMostRecentProxy")
void getMostRecentProxy(final String forwardedFor, final Optional<String> expectedMostRecentProxy) {
assertEquals(expectedMostRecentProxy, WebsocketHandshakeCompleteHandler.getMostRecentProxy(forwardedFor));
}
private static Stream<Arguments> argumentsForGetMostRecentProxy() {
return Stream.of(
arguments(null, Optional.empty()),
arguments("", Optional.empty()),
arguments(" ", Optional.empty()),
arguments("203.0.113.195,", Optional.empty()),
arguments("203.0.113.195, ", Optional.empty()),
arguments("203.0.113.195", Optional.of("203.0.113.195")),
arguments("203.0.113.195, 70.41.3.18, 150.172.238.178", Optional.of("150.172.238.178"))
);
}
}