diff --git a/pom.xml b/pom.xml index 26659788a..2329a9da3 100644 --- a/pom.xml +++ b/pom.xml @@ -299,11 +299,6 @@ simple-grpc-runtime ${simple-grpc.version} - - org.signal.forks - noise-java - 0.1.1 - org.apache.logging.log4j log4j-bom diff --git a/service/config/sample-secrets-bundle.yml b/service/config/sample-secrets-bundle.yml index c136bf451..d499dcde6 100644 --- a/service/config/sample-secrets-bundle.yml +++ b/service/config/sample-secrets-bundle.yml @@ -102,7 +102,3 @@ turn.cloudflare.apiToken: ABCDEFGHIJKLM linkDevice.secret: AAAAAAAAAAA= tlsKeyStore.password: unset - -noiseTunnel.tlsKeyStorePassword: ABCDEFGHIJKLMNOPQRSTUVWXYZ -noiseTunnel.noiseStaticPrivateKey: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA diff --git a/service/config/sample.yml b/service/config/sample.yml index 9f684c2f9..24a35bcbe 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -529,15 +529,6 @@ turn: linkDevice: secret: secret://linkDevice.secret -noiseTunnel: - webSocketPort: 8444 - directPort: 8445 - tlsKeyStoreFile: /path/to/file.p12 - tlsKeyStoreEntryAlias: example.com - tlsKeyStorePassword: secret://noiseTunnel.tlsKeyStorePassword - noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey - recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret - externalRequestFilter: grpcMethods: - com.example.grpc.ExampleService/exampleMethod @@ -548,3 +539,6 @@ externalRequestFilter: idlePrimaryDeviceReminder: minIdleDuration: P30D + +grpc: + port: 50051 diff --git a/service/pom.xml b/service/pom.xml index f9f27173c..364eedc56 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -82,11 +82,6 @@ libsignal-server - - org.signal.forks - noise-java - - org.signal simple-grpc-runtime diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 9e7d1f417..2d717a4bf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -39,13 +39,13 @@ import org.whispersystems.textsecuregcm.configuration.FcmConfiguration; import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration; import org.whispersystems.textsecuregcm.configuration.GenericZkConfig; import org.whispersystems.textsecuregcm.configuration.GooglePlayBillingConfiguration; +import org.whispersystems.textsecuregcm.configuration.GrpcConfiguration; import org.whispersystems.textsecuregcm.configuration.IdlePrimaryDeviceReminderConfiguration; import org.whispersystems.textsecuregcm.configuration.KeyTransparencyServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.LinkDeviceSecretConfiguration; import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; -import org.whispersystems.textsecuregcm.configuration.NoiseTunnelConfiguration; import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration; import org.whispersystems.textsecuregcm.configuration.OpenTelemetryConfiguration; import org.whispersystems.textsecuregcm.configuration.PagedSingleUseKEMPreKeyStoreConfiguration; @@ -315,11 +315,6 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration(); - @Valid - @NotNull - @JsonProperty - private NoiseTunnelConfiguration noiseTunnel; - @Valid @NotNull @JsonProperty @@ -348,6 +343,11 @@ public class WhisperServerConfiguration extends Configuration { @NotNull private RetryConfiguration generalRedisRetry = new RetryConfiguration(); + @NotNull + @Valid + @JsonProperty + private GrpcConfiguration grpc; + public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() { return tlsKeyStore; } @@ -551,10 +551,6 @@ public class WhisperServerConfiguration extends Configuration { return virtualThread; } - public NoiseTunnelConfiguration getNoiseTunnelConfiguration() { - return noiseTunnel; - } - public ExternalRequestFilterConfiguration getExternalRequestFilterConfiguration() { return externalRequestFilter; } @@ -582,4 +578,8 @@ public class WhisperServerConfiguration extends Configuration { public RetryConfiguration getGeneralRedisRetryConfiguration() { return generalRedisRetry; } + + public GrpcConfiguration getGrpc() { + return grpc; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2bb6110e3..4b808bf0c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -23,12 +23,14 @@ import io.dropwizard.core.setup.Environment; import io.dropwizard.jetty.HttpsConnectorFactory; import io.dropwizard.lifecycle.setup.LifecycleEnvironment; import io.grpc.ServerBuilder; +import io.grpc.ServerInterceptors; +import io.grpc.ServerServiceDefinition; +import io.grpc.netty.NettyServerBuilder; import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder; import io.lettuce.core.metrics.MicrometerOptions; import io.lettuce.core.resource.ClientResources; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; -import io.netty.channel.local.LocalAddress; import io.netty.channel.socket.nio.NioDatagramChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.resolver.ResolvedAddressTypes; @@ -38,16 +40,12 @@ import jakarta.servlet.DispatcherType; import jakarta.servlet.Filter; import jakarta.servlet.ServletRegistration; import java.io.ByteArrayInputStream; -import java.io.FileInputStream; +import java.net.InetSocketAddress; import java.net.http.HttpClient; import java.nio.charset.StandardCharsets; -import java.security.KeyStore; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; import java.time.Clock; import java.time.Duration; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.EnumSet; import java.util.List; @@ -62,7 +60,6 @@ import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.function.Function; import java.util.stream.Stream; -import javax.annotation.Nullable; import org.eclipse.jetty.websocket.core.WebSocketComponents; import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; @@ -154,12 +151,8 @@ import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService; import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService; import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; import org.whispersystems.textsecuregcm.grpc.ValidatingInterceptor; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup; -import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer; +import org.whispersystems.textsecuregcm.grpc.net.ManagedGrpcServer; import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup; -import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectTunnelServer; -import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer; import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer; import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; @@ -259,9 +252,9 @@ import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager; import org.whispersystems.textsecuregcm.subscriptions.GooglePlayBillingManager; import org.whispersystems.textsecuregcm.subscriptions.StripeManager; import org.whispersystems.textsecuregcm.util.BufferingInterceptor; -import org.whispersystems.textsecuregcm.util.ResilienceUtil; import org.whispersystems.textsecuregcm.util.ManagedAwsCrt; import org.whispersystems.textsecuregcm.util.ManagedExecutors; +import org.whispersystems.textsecuregcm.util.ResilienceUtil; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider; @@ -636,9 +629,8 @@ public class WhisperServerService extends Application dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion()); SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, storageServiceExecutor, retryExecutor, config.getSecureStorageServiceConfiguration()); - final GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager(); DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, - grpcClientConnectionManager, disconnectionRequestListenerExecutor, retryExecutor); + disconnectionRequestListenerExecutor, retryExecutor); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, retryExecutor, asyncCdnS3Client, config.getCdnConfiguration().bucket()); MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, @@ -828,127 +820,63 @@ public class WhisperServerService extends Application serverBuilder) { - // Note: interceptors run in the reverse order they are added; the remote deprecation filter - // depends on the user-agent context so it has to come first here! - // http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor- - serverBuilder - .intercept( - new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(), - config.getExternalRequestFilterConfiguration().grpcMethods())) - .intercept(validatingInterceptor) - .intercept(metricServerInterceptor) - .intercept(errorMappingInterceptor) - .intercept(remoteDeprecationFilter) - .intercept(requestAttributesInterceptor) - .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)) - .addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters)) - .addService(new KeysAnonymousGrpcService(accountsManager, keysManager, zkSecretParams, Clock.systemUTC())) - .addService(new PaymentsGrpcService(currencyManager)) - .addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config)) - .addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkSecretParams)); - } - }; + final List authenticatedServices = Stream.of( + new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager), + ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters), + new KeysGrpcService(accountsManager, keysManager, rateLimiters), + new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager, + config.getBadges(), profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations)) + .map(bindableService -> ServerInterceptors.intercept(bindableService, + // Note: interceptors run in the reverse order they are added; the remote deprecation filter + // depends on the user-agent context so it has to come first here! + validatingInterceptor, + metricServerInterceptor, + errorMappingInterceptor, + remoteDeprecationFilter, + requestAttributesInterceptor, + requireAuthenticationInterceptor)) + .toList(); - final ManagedLocalGrpcServer authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, localEventLoopGroup) { - @Override - protected void configureServer(final ServerBuilder serverBuilder) { - // Note: interceptors run in the reverse order they are added; the remote deprecation filter - // depends on the user-agent context so it has to come first here! - // http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor- - serverBuilder - .intercept(validatingInterceptor) - .intercept(metricServerInterceptor) - .intercept(errorMappingInterceptor) - .intercept(remoteDeprecationFilter) - .intercept(requestAttributesInterceptor) - .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)) - .addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager)) - .addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters)) - .addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters)) - .addService(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager, - config.getBadges(), profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations)); - } - }; + final List unauthenticatedServices = Stream.of( + new AccountsAnonymousGrpcService(accountsManager, rateLimiters), + new KeysAnonymousGrpcService(accountsManager, keysManager, zkSecretParams, Clock.systemUTC()), + new PaymentsGrpcService(currencyManager), + ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config), + new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkSecretParams)) + .map(bindableService -> ServerInterceptors.intercept(bindableService, + // Note: interceptors run in the reverse order they are added; the remote deprecation filter + // depends on the user-agent context so it has to come first here! + grpcExternalRequestFilter, + validatingInterceptor, + metricServerInterceptor, + errorMappingInterceptor, + remoteDeprecationFilter, + requestAttributesInterceptor, + prohibitAuthenticationInterceptor)) + .toList(); - @Nullable final X509Certificate[] noiseWebSocketTlsCertificateChain; - @Nullable final PrivateKey noiseWebSocketTlsPrivateKey; + final ServerBuilder serverBuilder = + NettyServerBuilder.forAddress(new InetSocketAddress(config.getGrpc().bindAddress(), config.getGrpc().port())); + authenticatedServices.forEach(serverBuilder::addService); + unauthenticatedServices.forEach(serverBuilder::addService); + final ManagedGrpcServer exposedGrpcServer = new ManagedGrpcServer(serverBuilder.build()); - if (config.getNoiseTunnelConfiguration().tlsKeyStoreFile() != null && - config.getNoiseTunnelConfiguration().tlsKeyStoreEntryAlias() != null && - config.getNoiseTunnelConfiguration().tlsKeyStorePassword() != null) { - - try (final FileInputStream websocketNoiseTunnelTlsKeyStoreInputStream = new FileInputStream(config.getNoiseTunnelConfiguration().tlsKeyStoreFile())) { - final KeyStore keyStore = KeyStore.getInstance("PKCS12"); - keyStore.load(websocketNoiseTunnelTlsKeyStoreInputStream, config.getNoiseTunnelConfiguration().tlsKeyStorePassword().value().toCharArray()); - - final KeyStore.PrivateKeyEntry privateKeyEntry = (KeyStore.PrivateKeyEntry) keyStore.getEntry( - config.getNoiseTunnelConfiguration().tlsKeyStoreEntryAlias(), - new KeyStore.PasswordProtection(config.getNoiseTunnelConfiguration().tlsKeyStorePassword().value().toCharArray())); - - noiseWebSocketTlsCertificateChain = - Arrays.copyOf(privateKeyEntry.getCertificateChain(), privateKeyEntry.getCertificateChain().length, X509Certificate[].class); - - noiseWebSocketTlsPrivateKey = privateKeyEntry.getPrivateKey(); - } - } else { - noiseWebSocketTlsCertificateChain = null; - noiseWebSocketTlsPrivateKey = null; - } - - final ExecutorService noiseWebSocketDelegatedTaskExecutor = ExecutorServiceBuilder.of(environment, "noiseWebsocketDelegatedTask") - .minThreads(8) - .maxThreads(8) - .allowCoreThreadTimeOut(false) - .build(); - - final ManagedNioEventLoopGroup noiseTunnelEventLoopGroup = new ManagedNioEventLoopGroup(); - - final NoiseWebSocketTunnelServer noiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer( - config.getNoiseTunnelConfiguration().webSocketPort(), - noiseWebSocketTlsCertificateChain, - noiseWebSocketTlsPrivateKey, - noiseTunnelEventLoopGroup, - noiseWebSocketDelegatedTaskExecutor, - grpcClientConnectionManager, - clientPublicKeysManager, - config.getNoiseTunnelConfiguration().noiseStaticKeyPair(), - authenticatedGrpcServerAddress, - anonymousGrpcServerAddress, - config.getNoiseTunnelConfiguration().recognizedProxySecret().value()); - - final NoiseDirectTunnelServer noiseDirectTunnelServer = new NoiseDirectTunnelServer( - config.getNoiseTunnelConfiguration().directPort(), - noiseTunnelEventLoopGroup, - grpcClientConnectionManager, - clientPublicKeysManager, - config.getNoiseTunnelConfiguration().noiseStaticKeyPair(), - authenticatedGrpcServerAddress, - anonymousGrpcServerAddress); - - environment.lifecycle().manage(localEventLoopGroup); environment.lifecycle().manage(dnsResolutionEventLoopGroup); - environment.lifecycle().manage(anonymousGrpcServer); - environment.lifecycle().manage(authenticatedGrpcServer); - environment.lifecycle().manage(noiseTunnelEventLoopGroup); - environment.lifecycle().manage(noiseWebSocketTunnelServer); - environment.lifecycle().manage(noiseDirectTunnelServer); + environment.lifecycle().manage(exposedGrpcServer); final List filters = new ArrayList<>(); filters.add(remoteDeprecationFilter); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java index 04e7b0c51..564732386 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManager.java @@ -27,8 +27,6 @@ import java.util.concurrent.ScheduledExecutorService; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; @@ -47,7 +45,6 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil; public class DisconnectionRequestManager extends RedisPubSubAdapter implements Managed { private final FaultTolerantRedisClient pubSubClient; - private final GrpcClientConnectionManager grpcClientConnectionManager; private final Executor listenerEventExecutor; private final ScheduledExecutorService retryExecutor; @@ -74,12 +71,10 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter { - grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(accountIdentifier, deviceId)); - listeners.getOrDefault(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), Collections.emptyList()) .forEach(listener -> listenerEventExecutor.execute(() -> { try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptor.java deleted file mode 100644 index 93530a002..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptor.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.whispersystems.textsecuregcm.auth.grpc; - -import io.grpc.ServerCall; -import io.grpc.ServerInterceptor; -import java.util.Optional; -import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; - -abstract class AbstractAuthenticationInterceptor implements ServerInterceptor { - - private final GrpcClientConnectionManager grpcClientConnectionManager; - - AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { - this.grpcClientConnectionManager = grpcClientConnectionManager; - } - - protected Optional getAuthenticatedDevice(final ServerCall call) - throws ChannelNotFoundException { - - return grpcClientConnectionManager.getAuthenticatedDevice(call); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java index b14465d51..c0ec26e82 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptor.java @@ -1,40 +1,30 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ package org.whispersystems.textsecuregcm.auth.grpc; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.Status; -import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; -import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; /** * A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not - * originate from a channel that is associated with an authenticated device. Calls with an associated authenticated - * device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined - * (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will - * reject the call with a status of {@code UNAVAILABLE}. + * contain an authorization header in the request metdata. Calls with an associated authenticated device are closed with + * an {@code UNAUTHENTICATED} status. */ -public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor { - - public ProhibitAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { - super(grpcClientConnectionManager); - } +public class ProhibitAuthenticationInterceptor implements ServerInterceptor { @Override public ServerCall.Listener interceptCall(final ServerCall call, - final Metadata headers, - final ServerCallHandler next) { - - try { - return getAuthenticatedDevice(call) - // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited - // service via an authenticated connection, then that's actually a server configuration issue and not a - // problem with the client's request. - .map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL)) - .orElseGet(() -> next.startCall(call, headers)); - } catch (final ChannelNotFoundException e) { - return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); + final Metadata headers, final ServerCallHandler next) { + final String authHeaderString = headers.get(Metadata.Key.of(RequireAuthenticationInterceptor.AUTHORIZATION_HEADER, Metadata.ASCII_STRING_MARSHALLER)); + if (authHeaderString != null) { + call.close(Status.UNAUTHENTICATED.withDescription("authorization header forbidden"), new Metadata()); + return new ServerCall.Listener<>() {}; } + return next.startCall(call, headers); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java index f03052bab..14ebec7cb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptor.java @@ -1,43 +1,68 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ package org.whispersystems.textsecuregcm.auth.grpc; +import io.dropwizard.auth.basic.BasicCredentials; import io.grpc.Context; import io.grpc.Contexts; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.Status; -import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; +import java.util.Optional; +import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import org.whispersystems.textsecuregcm.util.HeaderUtils; /** - * A "require authentication" interceptor requires that requests be issued from a connection that is associated with an - * authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED} - * status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed - * before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}. + * A "require authentication" interceptor authenticates requests and attaches the {@link AuthenticatedDevice} to the + * current gRPC context. Calls without authentication or with invalid credentials are closed with an + * {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined (i.e. because the accounts + * database is unavailable), the interceptor will reject the call with a status of {@code UNAVAILABLE}. */ -public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor { +public class RequireAuthenticationInterceptor implements ServerInterceptor { - public RequireAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { - super(grpcClientConnectionManager); + static final String AUTHORIZATION_HEADER = "authorization"; + + private final AccountAuthenticator authenticator; + + public RequireAuthenticationInterceptor(final AccountAuthenticator authenticator) { + this.authenticator = authenticator; } @Override public ServerCall.Listener interceptCall(final ServerCall call, - final Metadata headers, - final ServerCallHandler next) { + final Metadata headers, final ServerCallHandler next) { + final String authHeaderString = headers.get( + Metadata.Key.of(AUTHORIZATION_HEADER, Metadata.ASCII_STRING_MARSHALLER)); - try { - return getAuthenticatedDevice(call) - .map(authenticatedDevice -> Contexts.interceptCall(Context.current() - .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), - call, headers, next)) - // Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required - // service via an unauthenticated connection, then that's actually a server configuration issue and not a - // problem with the client's request. - .orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL)); - } catch (final ChannelNotFoundException e) { - return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); + if (authHeaderString == null) { + return ServerInterceptorUtil.closeWithStatus(call, + Status.UNAUTHENTICATED.withDescription("missing authorization header")); } + + final Optional basicCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeaderString); + if (basicCredentials.isEmpty()) { + return ServerInterceptorUtil.closeWithStatus(call, + Status.UNAUTHENTICATED.withDescription("malformed authorization header")); + } + + final Optional authenticated = + authenticator.authenticate(basicCredentials.get()); + if (authenticated.isEmpty()) { + return ServerInterceptorUtil.closeWithStatus(call, + Status.UNAUTHENTICATED.withDescription("invalid credentials")); + } + + final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice( + authenticated.get().accountIdentifier(), + authenticated.get().deviceId()); + + return Contexts.interceptCall(Context.current() + .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice), + call, headers, next); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/GrpcConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/GrpcConfiguration.java new file mode 100644 index 000000000..fb3c08742 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/GrpcConfiguration.java @@ -0,0 +1,15 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.configuration; + +import jakarta.validation.constraints.NotNull; + +public record GrpcConfiguration(@NotNull String bindAddress, @NotNull Integer port) { + public GrpcConfiguration { + if (bindAddress == null || bindAddress.isEmpty()) { + bindAddress = "localhost"; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/NoiseTunnelConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/NoiseTunnelConfiguration.java deleted file mode 100644 index 521bd0af9..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/NoiseTunnelConfiguration.java +++ /dev/null @@ -1,25 +0,0 @@ -package org.whispersystems.textsecuregcm.configuration; - -import jakarta.validation.constraints.NotNull; -import jakarta.validation.constraints.Positive; -import javax.annotation.Nullable; -import org.signal.libsignal.protocol.InvalidKeyException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPrivateKey; -import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes; -import org.whispersystems.textsecuregcm.configuration.secrets.SecretString; - -public record NoiseTunnelConfiguration(@Positive int webSocketPort, - @Positive int directPort, - @Nullable String tlsKeyStoreFile, - @Nullable String tlsKeyStoreEntryAlias, - @Nullable SecretString tlsKeyStorePassword, - @NotNull SecretBytes noiseStaticPrivateKey, - @NotNull SecretString recognizedProxySecret) { - - public ECKeyPair noiseStaticKeyPair() throws InvalidKeyException { - final ECPrivateKey privateKey = new ECPrivateKey(noiseStaticPrivateKey().value()); - - return new ECKeyPair(privateKey.publicKey(), privateKey); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptor.java deleted file mode 100644 index 24cec57ed..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptor.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import io.grpc.Context; -import io.grpc.ForwardingServerCallListener; -import io.grpc.Grpc; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.netty.channel.local.LocalAddress; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; - -/** - * Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with - * {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise. - */ -public class ChannelShutdownInterceptor implements ServerInterceptor { - - private final GrpcClientConnectionManager grpcClientConnectionManager; - - public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { - this.grpcClientConnectionManager = grpcClientConnectionManager; - } - - @Override - public ServerCall.Listener interceptCall(final ServerCall call, - final Metadata headers, - final ServerCallHandler next) { - - if (!grpcClientConnectionManager.handleServerCallStart(call)) { - // Don't allow new calls if the connection is getting ready to close - return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); - } - - return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) { - @Override - public void onComplete() { - grpcClientConnectionManager.handleServerCallComplete(call); - super.onComplete(); - } - - @Override - public void onCancel() { - grpcClientConnectionManager.handleServerCallComplete(call); - super.onCancel(); - } - }; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java index b4fb0d169..9c468ec0d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptor.java @@ -1,41 +1,108 @@ package org.whispersystems.textsecuregcm.grpc; +import com.google.common.net.HttpHeaders; import io.grpc.Context; import io.grpc.Contexts; +import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.Status; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import javax.annotation.Nullable; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * The request attributes interceptor makes request attributes from the underlying remote channel available to service + * The request attributes interceptor makes common request attributes from call metadata available to service * implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}. - * All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if - * request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started). * * @see RequestAttributesUtil */ public class RequestAttributesInterceptor implements ServerInterceptor { - private final GrpcClientConnectionManager grpcClientConnectionManager; + private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class); - public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) { - this.grpcClientConnectionManager = grpcClientConnectionManager; - } + private static final Metadata.Key ACCEPT_LANG_KEY = + Metadata.Key.of(HttpHeaders.ACCEPT_LANGUAGE, Metadata.ASCII_STRING_MARSHALLER); + + private static final Metadata.Key USER_AGENT_KEY = + Metadata.Key.of(HttpHeaders.USER_AGENT, Metadata.ASCII_STRING_MARSHALLER); + + private static final Metadata.Key X_FORWARDED_FOR_KEY = + Metadata.Key.of(HttpHeaders.X_FORWARDED_FOR, Metadata.ASCII_STRING_MARSHALLER); @Override public ServerCall.Listener interceptCall(final ServerCall call, final Metadata headers, final ServerCallHandler next) { + final String userAgentHeader = headers.get(USER_AGENT_KEY); + final String acceptLanguageHeader = headers.get(ACCEPT_LANG_KEY); + final String xForwardedForHeader = headers.get(X_FORWARDED_FOR_KEY); - try { - return Contexts.interceptCall(Context.current() - .withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, - grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next); - } catch (final ChannelNotFoundException e) { + final Optional remoteAddress = getMostRecentProxy(xForwardedForHeader) + .flatMap(mostRecentProxy -> { + try { + return Optional.of(InetAddress.ofLiteral(mostRecentProxy)); + } catch (IllegalArgumentException e) { + log.warn("Failed to parse most recent proxy {} as an IP address", mostRecentProxy, e); + return Optional.empty(); + } + }) + .or(() -> { + log.warn("No usable X-Forwarded-For header present, using remote socket address"); + final SocketAddress socketAddress = call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + if (socketAddress == null || !(socketAddress instanceof InetSocketAddress inetAddress)) { + log.warn("Remote socket address not present or is not an inet address: {}", socketAddress); + return Optional.empty(); + } + return Optional.of(inetAddress.getAddress()); + }); + if (!remoteAddress.isPresent()) { return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE); } + + @Nullable List acceptLanguages = Collections.emptyList(); + if (StringUtils.isNotBlank(acceptLanguageHeader)) { + try { + acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader); + } catch (final IllegalArgumentException e) { + log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); + } + } + + final RequestAttributes requestAttributes = + new RequestAttributes(remoteAddress.get(), userAgentHeader, acceptLanguages); + return Contexts.interceptCall( + Context.current().withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), + call, headers, next); } + + /** + * Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header. + * + * @param forwardedFor the value of an X-Forwarded-For header + * @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or + * {@code forwardedFor} was null + * @see X-Forwarded-For - HTTP | + * MDN + */ + public static Optional getMostRecentProxy(@Nullable final String forwardedFor) { + return Optional.ofNullable(forwardedFor) + .map(ff -> { + final int idx = forwardedFor.lastIndexOf(',') + 1; + return idx < forwardedFor.length() + ? forwardedFor.substring(idx).trim() + : null; + }) + .filter(StringUtils::isNotBlank); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java deleted file mode 100644 index a65c3f5aa..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ErrorHandler.java +++ /dev/null @@ -1,41 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import javax.crypto.BadPaddingException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.util.ExceptionUtils; - -/** - * An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions - * thrown in inbound handlers into {@link OutboundCloseErrorMessage}s. - */ -public class ErrorHandler extends ChannelInboundHandlerAdapter { - private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class); - - private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage( - OutboundCloseErrorMessage.Code.NOISE_ERROR, - "Noise encryption error"); - - @Override - public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) { - final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) { - case NoiseHandshakeException e -> new OutboundCloseErrorMessage( - OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR, - e.getMessage()); - case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE; - case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE; - default -> { - log.warn("An unexpected exception reached the end of the pipeline", cause); - yield new OutboundCloseErrorMessage( - OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR, - cause.getMessage()); - } - }; - - context.writeAndFlush(closeMessage) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java deleted file mode 100644 index 77c966b13..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/EstablishLocalGrpcConnectionHandler.java +++ /dev/null @@ -1,128 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Tag; -import io.micrometer.core.instrument.Tags; -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.util.ReferenceCountUtil; -import java.net.InetAddress; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; - -/** - * An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering - * any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC - * server. - */ -public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter { - private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class); - - private final GrpcClientConnectionManager grpcClientConnectionManager; - - private final LocalAddress authenticatedGrpcServerAddress; - private final LocalAddress anonymousGrpcServerAddress; - private final FramingType framingType; - - private final List pendingReads = new ArrayList<>(); - - private static final String CONNECTION_ESTABLISHED_COUNTER_NAME = MetricsUtil.name(EstablishLocalGrpcConnectionHandler.class, "established"); - - public EstablishLocalGrpcConnectionHandler(final GrpcClientConnectionManager grpcClientConnectionManager, - final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress, - final FramingType framingType) { - - this.grpcClientConnectionManager = grpcClientConnectionManager; - - this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress; - this.anonymousGrpcServerAddress = anonymousGrpcServerAddress; - this.framingType = framingType; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) { - pendingReads.add(message); - } - - @Override - public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) { - if (event instanceof NoiseIdentityDeterminedEvent( - final Optional authenticatedDevice, - InetAddress remoteAddress, String userAgent, String acceptLanguage)) { - // We assume that we'll only get a completed handshake event if the handshake met all authentication requirements - // for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to - // connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the - // authenticated service. - final LocalAddress grpcServerAddress = authenticatedDevice.isPresent() - ? authenticatedGrpcServerAddress - : anonymousGrpcServerAddress; - - GrpcClientConnectionManager.handleHandshakeInitiated( - remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage); - - final List tags = UserAgentTagUtil.getLibsignalAndPlatformTags(userAgent); - Metrics.counter(CONNECTION_ESTABLISHED_COUNTER_NAME, Tags.of(tags) - .and("authenticated", Boolean.toString(authenticatedDevice.isPresent())) - .and("framingType", framingType.name())) - .increment(); - - new Bootstrap() - .remoteAddress(grpcServerAddress) - .channel(LocalChannel.class) - .group(remoteChannelContext.channel().eventLoop()) - .handler(new ChannelInitializer() { - @Override - protected void initChannel(final LocalChannel localChannel) { - localChannel.pipeline().addLast(new ProxyHandler(remoteChannelContext.channel())); - } - }) - .connect() - .addListener((ChannelFutureListener) localChannelFuture -> { - if (localChannelFuture.isSuccess()) { - grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(), - remoteChannelContext.channel(), - authenticatedDevice); - - // Close the local connection if the remote channel closes and vice versa - remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close()); - localChannelFuture.channel().closeFuture().addListener(closeFuture -> - remoteChannelContext.channel() - .write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed")) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE)); - - remoteChannelContext.pipeline() - .addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel())); - - // Flush any buffered reads we accumulated while waiting to open the connection - pendingReads.forEach(remoteChannelContext::fireChannelRead); - pendingReads.clear(); - - remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this); - } else { - log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause()); - remoteChannelContext.close(); - } - }); - } - - remoteChannelContext.fireUserEventTriggered(event); - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - pendingReads.forEach(ReferenceCountUtil::release); - pendingReads.clear(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/FramingType.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/FramingType.java deleted file mode 100644 index 0343b8e9a..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/FramingType.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -public enum FramingType { - NOISE_DIRECT, - WEBSOCKET -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java deleted file mode 100644 index ad12da9fe..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManager.java +++ /dev/null @@ -1,268 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.google.common.annotations.VisibleForTesting; -import io.grpc.Grpc; -import io.grpc.ServerCall; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.util.AttributeKey; -import java.net.InetAddress; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.Nullable; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; -import org.whispersystems.textsecuregcm.grpc.RequestAttributes; -import org.whispersystems.textsecuregcm.util.ClosableEpoch; - -/** - * A client connection manager associates a local connection to a local gRPC server with a remote connection through a - * Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated - * identity of the device that opened the connection (for non-anonymous connections). It can also close connections - * associated with a given device if that device's credentials have changed and clients must reauthenticate. - *

- * In general, all {@link ServerCall}s must have a local address that in turn should be resolvable to - * a remote channel, which must have associated request attributes and authentication status. It is possible - * that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the - * narrow window between a server call being created and the start of call execution, in which case accessor methods - * in this class will throw a {@link ChannelNotFoundException}. - *

- * A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to - * identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s. - * Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may - * be called from any application code. - */ -public class GrpcClientConnectionManager { - - private final Map remoteChannelsByLocalAddress = new ConcurrentHashMap<>(); - private final Map> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>(); - - @VisibleForTesting - static final AttributeKey AUTHENTICATED_DEVICE_ATTRIBUTE_KEY = - AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice"); - - @VisibleForTesting - public static final AttributeKey REQUEST_ATTRIBUTES_KEY = - AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes"); - - @VisibleForTesting - static final AttributeKey EPOCH_ATTRIBUTE_KEY = - AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch"); - - private static final OutboundCloseErrorMessage SERVER_CLOSED = - new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"); - - private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class); - - /** - * Returns the authenticated device associated with the given server call, if any. If the connection is anonymous - * (i.e. unauthenticated), the returned value will be empty. - * - * @param serverCall the gRPC server call for which to find an authenticated device - * - * @return the authenticated device associated with the given local address, if any - * - * @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this - * generally indicates that the channel has closed while request processing is still in progress - */ - public Optional getAuthenticatedDevice(final ServerCall serverCall) - throws ChannelNotFoundException { - - return getAuthenticatedDevice(getRemoteChannel(serverCall)); - } - - @VisibleForTesting - Optional getAuthenticatedDevice(final Channel remoteChannel) { - return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get()); - } - - /** - * Returns the request attributes associated with the given server call. - * - * @param serverCall the gRPC server call for which to retrieve request attributes - * - * @return the request attributes associated with the given server call - * - * @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this - * generally indicates that the channel has closed while request processing is still in progress - */ - public RequestAttributes getRequestAttributes(final ServerCall serverCall) throws ChannelNotFoundException { - return getRequestAttributes(getRemoteChannel(serverCall)); - } - - @VisibleForTesting - RequestAttributes getRequestAttributes(final Channel remoteChannel) { - final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get(); - - if (requestAttributes == null) { - throw new IllegalStateException("Channel does not have request attributes"); - } - - return requestAttributes; - } - - /** - * Handles the start of a server call, incrementing the active call count for the remote channel associated with the - * given server call. - * - * @param serverCall the server call to start - * - * @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the - * underlying channel is closing - */ - public boolean handleServerCallStart(final ServerCall serverCall) { - try { - return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive(); - } catch (final ChannelNotFoundException e) { - // This would only happen if the channel had already closed, which is certainly possible. In this case, the call - // should certainly not proceed. - return false; - } - } - - /** - * Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel - * associated with the given server call. - * - * @param serverCall the server call to complete - */ - public void handleServerCallComplete(final ServerCall serverCall) { - try { - getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart(); - } catch (final ChannelNotFoundException ignored) { - // In practice, we'd only get here if the channel has already closed, so we can just ignore the exception - } - } - - /** - * Closes any client connections to this host associated with the given authenticated device. - * - * @param authenticatedDevice the authenticated device for which to close connections - */ - public void closeConnection(final AuthenticatedDevice authenticatedDevice) { - // Channels will actually get removed from the list/map by their closeFuture listeners. We copy the list to avoid - // concurrent modification; it's possible (though practically unlikely) that a channel can close and remove itself - // from the list while we're still iterating, resulting in a `ConcurrentModificationException`. - final List channelsToClose = - new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList())); - - channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close()); - } - - private static void closeRemoteChannel(final Channel channel) { - channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - } - - @VisibleForTesting - @Nullable List getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) { - return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice); - } - - private Channel getRemoteChannel(final ServerCall serverCall) throws ChannelNotFoundException { - return getRemoteChannel(getLocalAddress(serverCall)); - } - - @VisibleForTesting - Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException { - final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress); - - if (remoteChannel == null) { - throw new ChannelNotFoundException(); - } - - return remoteChannelsByLocalAddress.get(localAddress); - } - - private static LocalAddress getLocalAddress(final ServerCall serverCall) { - // In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel. - // The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to - // a proxied Noise channel. - if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) { - throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); - } - - return localAddress; - } - - /** - * Handles receipt of a handshake message and associates attributes and headers from the handshake - * request with the channel via which the handshake took place. - * - * @param channel the channel where the handshake was initiated - * @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake - * @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null} - * @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be - * {@code null} - */ - public static void handleHandshakeInitiated(final Channel channel, - final InetAddress preferredRemoteAddress, - @Nullable final String userAgentHeader, - @Nullable final String acceptLanguageHeader) { - - @Nullable List acceptLanguages = Collections.emptyList(); - - if (StringUtils.isNotBlank(acceptLanguageHeader)) { - try { - acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader); - } catch (final IllegalArgumentException e) { - log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e); - } - } - - channel.attr(REQUEST_ATTRIBUTES_KEY) - .set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages)); - } - - /** - * Handles successful establishment of a Noise connection from a remote client to a local gRPC server. - * - * @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server - * @param remoteChannel the channel from the remote client to the Noise tunnel - * @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection - */ - void handleConnectionEstablished(final LocalChannel localChannel, - final Channel remoteChannel, - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional maybeAuthenticatedDevice) { - - maybeAuthenticatedDevice.ifPresent(authenticatedDevice -> - remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice)); - - remoteChannel.attr(EPOCH_ATTRIBUTE_KEY) - .set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel))); - - remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel); - - getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> - remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> { - final List channels = existingChannelList != null ? existingChannelList : new ArrayList<>(); - channels.add(remoteChannel); - - return channels; - })); - - remoteChannel.closeFuture().addListener(closeFuture -> { - remoteChannelsByLocalAddress.remove(localChannel.localAddress()); - - getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice -> - remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> { - if (existingChannelList == null) { - return null; - } - - existingChannelList.remove(remoteChannel); - - return existingChannelList.isEmpty() ? null : existingChannelList; - })); - }); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java deleted file mode 100644 index 2bd807967..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandler.java +++ /dev/null @@ -1,32 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * An HAProxy message handler handles decoded HAProxyMessage instances, removing itself from the pipeline once it has - * either handled a proxy protocol message or determined that no such message is coming. - */ -public class HAProxyMessageHandler extends ChannelInboundHandlerAdapter { - - private static final Logger log = LoggerFactory.getLogger(HAProxyMessageHandler.class); - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof HAProxyMessage haProxyMessage) { - // Some network/deployment configurations will send us a proxy protocol message, but we don't use it. We still - // need to clear it from the pipeline to avoid confusing the TLS machinery, though. - log.debug("Discarding HAProxy message: {}", haProxyMessage); - haProxyMessage.release(); - } else { - super.channelRead(context, message); - } - - // Regardless of the type of the first message, we'll only ever receive zero or one HAProxyMessages. After the first - // message, all others will just be "normal" messages, and our work here is done. - context.pipeline().remove(this); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java deleted file mode 100644 index cfd0205fc..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/HandshakePattern.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -public enum HandshakePattern { - NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"), - IK("Noise_IK_25519_ChaChaPoly_BLAKE2b"); - - private final String protocol; - - public String protocol() { - return protocol; - } - - - HandshakePattern(String protocol) { - this.protocol = protocol; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java deleted file mode 100644 index 5888174c4..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedDefaultEventLoopGroup.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.dropwizard.lifecycle.Managed; -import io.netty.channel.DefaultEventLoopGroup; - -/** - * A wrapper for a Netty {@link DefaultEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing - * Dropwizard to manage the lifecycle of the event loop group. - */ -public class ManagedDefaultEventLoopGroup extends DefaultEventLoopGroup implements Managed { - - @Override - public void stop() throws InterruptedException { - this.shutdownGracefully().await(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedGrpcServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedGrpcServer.java new file mode 100644 index 000000000..87ac886da --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedGrpcServer.java @@ -0,0 +1,29 @@ +package org.whispersystems.textsecuregcm.grpc.net; + +import io.dropwizard.lifecycle.Managed; +import io.grpc.Server; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +public class ManagedGrpcServer implements Managed { + private final Server server; + + public ManagedGrpcServer(Server server) { + this.server = server; + } + + @Override + public void start() throws IOException { + server.start(); + } + + @Override + public void stop() { + try { + server.shutdown().awaitTermination(5, TimeUnit.MINUTES); + } catch (final InterruptedException e) { + server.shutdownNow(); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java deleted file mode 100644 index 21cd3d846..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ManagedLocalGrpcServer.java +++ /dev/null @@ -1,49 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.dropwizard.lifecycle.Managed; -import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.netty.NettyServerBuilder; -import io.netty.channel.DefaultEventLoopGroup; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalServerChannel; -import java.io.IOException; -import java.util.concurrent.TimeUnit; - -/** - * A managed, local gRPC server configures and wraps a gRPC {@link Server} that listens on a Netty {@link LocalAddress} - * and whose lifecycle is managed by Dropwizard via the {@link Managed} interface. - */ -public abstract class ManagedLocalGrpcServer implements Managed { - - private final Server server; - - public ManagedLocalGrpcServer(final LocalAddress localAddress, - final DefaultEventLoopGroup eventLoopGroup) { - - final ServerBuilder serverBuilder = NettyServerBuilder.forAddress(localAddress) - .channelType(LocalServerChannel.class) - .bossEventLoopGroup(eventLoopGroup) - .workerEventLoopGroup(eventLoopGroup); - - configureServer(serverBuilder); - - server = serverBuilder.build(); - } - - protected abstract void configureServer(final ServerBuilder serverBuilder); - - @Override - public void start() throws IOException { - server.start(); - } - - @Override - public void stop() { - try { - server.shutdown().awaitTermination(5, TimeUnit.MINUTES); - } catch (final InterruptedException e) { - server.shutdownNow(); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java deleted file mode 100644 index 55bc979c7..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseException.java +++ /dev/null @@ -1,13 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import org.whispersystems.textsecuregcm.util.NoStackTraceException; - -/** - * Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/ - * format or a general encryption error). - */ -public class NoiseException extends NoStackTraceException { - public NoiseException(final String message) { - super(message); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java deleted file mode 100644 index 42e871e35..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.CipherState; -import com.southernstorm.noise.protocol.CipherStatePair; -import com.southernstorm.noise.protocol.Noise; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.concurrent.PromiseCombiner; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound - * messages - */ -public class NoiseHandler extends ChannelDuplexHandler { - - private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class); - private final CipherStatePair cipherStatePair; - - NoiseHandler(CipherStatePair cipherStatePair) { - this.cipherStatePair = cipherStatePair; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - try { - if (message instanceof ByteBuf frame) { - if (frame.readableBytes() > Noise.MAX_PACKET_LEN) { - throw new NoiseException("Invalid noise message length " + frame.readableBytes()); - } - // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. - // We'll need to copy it to a heap buffer. - handleInboundDataMessage(context, ByteBufUtil.getBytes(frame)); - } else { - // Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); - } - } finally { - ReferenceCountUtil.release(message); - } - } - - - private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes) - throws ShortBufferException, BadPaddingException { - final CipherState cipherState = cipherStatePair.getReceiver(); - // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer - final int plaintextLength = cipherState.decryptWithAd(null, - frameBytes, 0, - frameBytes, 0, - frameBytes.length); - - // Forward the decrypted plaintext along - context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength)); - } - - @Override - public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) - throws Exception { - if (message instanceof ByteBuf byteBuf) { - try { - // TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames - final CipherState cipherState = cipherStatePair.getSender(); - - // Server message might not fit in a single noise packet, break it up into as many chunks as we need - final PromiseCombiner pc = new PromiseCombiner(context.executor()); - while (byteBuf.isReadable()) { - final ByteBuf plaintext = byteBuf.readSlice(Math.min( - // need room for a 16-byte AEAD tag - Noise.MAX_PACKET_LEN - 16, - byteBuf.readableBytes())); - - final int plaintextLength = plaintext.readableBytes(); - - // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the - // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a - // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC. - final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()]; - plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes()); - - // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer - cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); - - pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer))); - } - pc.finish(promise); - } finally { - ReferenceCountUtil.release(byteBuf); - } - } else { - if (!(message instanceof OutboundCloseErrorMessage)) { - // Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames - // that get issued in response to exceptions) - log.warn("Unexpected object in pipeline: {}", message); - } - context.write(message, promise); - } - } - - @Override - public void handlerRemoved(ChannelHandlerContext var1) { - if (cipherStatePair != null) { - cipherStatePair.destroy(); - } - } - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java deleted file mode 100644 index b7a509728..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeException.java +++ /dev/null @@ -1,14 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import org.whispersystems.textsecuregcm.util.NoStackTraceException; - -/** - * Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or - * a general encryption error). - */ -public class NoiseHandshakeException extends NoStackTraceException { - - public NoiseHandshakeException(final String message) { - super(message); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java deleted file mode 100644 index 80d7f6105..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHandler.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.Noise; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufInputStream; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.util.ReferenceCountUtil; -import java.io.IOException; -import java.net.InetAddress; -import java.security.MessageDigest; -import java.util.Optional; -import java.util.UUID; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; -import org.whispersystems.textsecuregcm.util.ExceptionUtils; -import org.whispersystems.textsecuregcm.util.UUIDUtil; - -/** - * Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will - * encrypt/decrypt subsequent data frames - *

- * The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator - * handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler - * currently supports two types of handshakes. - *

- * The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator - * handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in - * the handshake message must match the server's stored public key for the device identified by the provided ACI and - * deviceId. - *

- * The second are NK handshakes which are anonymous. - *

- * Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server - * to begin processing the request without an initial message delay (fast open). - *

- * Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake, - * this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This - * handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data - * frames. - */ -public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter { - - private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder() - .setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY) - .build().toByteArray(); - private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder() - .setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK) - .build().toByteArray(); - - // We might get additional messages while we're waiting to process a handshake, so keep track of where we are - private boolean receivedHandshakeInit = false; - - private final ClientPublicKeysManager clientPublicKeysManager; - private final ECKeyPair ecKeyPair; - - public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) { - this.clientPublicKeysManager = clientPublicKeysManager; - this.ecKeyPair = ecKeyPair; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - try { - if (!(message instanceof NoiseHandshakeInit handshakeInit)) { - // Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); - } - if (receivedHandshakeInit) { - throw new NoiseHandshakeException("Should not receive messages until handshake complete"); - } - receivedHandshakeInit = true; - - if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) { - throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes()); - } - - // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. - // We'll need to copy it to a heap buffer - handleInboundHandshake(context, - handshakeInit.getRemoteAddress(), - handshakeInit.getHandshakePattern(), - ByteBufUtil.getBytes(handshakeInit.content())); - } finally { - ReferenceCountUtil.release(message); - } - } - - private void handleInboundHandshake( - final ChannelHandlerContext context, - final InetAddress remoteAddress, - final HandshakePattern handshakePattern, - final byte[] frameBytes) throws NoiseHandshakeException { - final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair); - final ByteBuf payload = handshakeHelper.read(frameBytes); - - // Parse the handshake message - final NoiseTunnelProtos.HandshakeInit handshakeInit; - try { - handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload)); - } catch (IOException e) { - throw new NoiseHandshakeException("Failed to parse handshake message"); - } - - switch (handshakePattern) { - case NK -> { - if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) { - throw new NoiseHandshakeException("Anonymous handshake should not include identifiers"); - } - handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty()); - } - case IK -> { - final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey() - .orElseThrow(() -> new IllegalStateException("No remote public key")); - final UUID accountIdentifier = aci(handshakeInit); - final byte deviceId = deviceId(handshakeInit); - clientPublicKeysManager - .findPublicKey(accountIdentifier, deviceId) - .whenCompleteAsync((storedPublicKey, throwable) -> { - if (throwable != null) { - context.fireExceptionCaught(ExceptionUtils.unwrap(throwable)); - return; - } - final boolean valid = storedPublicKey - .map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes())) - .orElse(false); - if (!valid) { - // Write a handshake response indicating that the client used the wrong public key - final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK); - context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage)) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - - context.fireExceptionCaught(new NoiseHandshakeException("Bad public key")); - return; - } - handleAuthenticated(context, - handshakeHelper, remoteAddress, handshakeInit, - Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId))); - }, context.executor()); - } - }; - } - - private void handleAuthenticated(final ChannelHandlerContext context, - final NoiseHandshakeHelper handshakeHelper, - final InetAddress remoteAddress, - final NoiseTunnelProtos.HandshakeInit handshakeInit, - final Optional maybeAuthenticatedDevice) { - context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent( - maybeAuthenticatedDevice, - remoteAddress, - handshakeInit.getUserAgent(), - handshakeInit.getAcceptLanguage())); - - // Now that we've authenticated, write the handshake response - final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK); - context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage)) - .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); - - // The handshake is complete. We can start intercepting read/write for noise encryption/decryption - // Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open - // request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already - // decrypted the fastOpen request - context.pipeline() - .addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split())); - context.pipeline().remove(NoiseHandshakeHandler.class); - if (!handshakeInit.getFastOpenRequest().isEmpty()) { - // The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll - // encrypt the response when the server writes back through us - context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer())); - } - } - - private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException { - try { - return UUIDUtil.fromByteString(handshakePayload.getAci()); - } catch (IllegalArgumentException e) { - throw new NoiseHandshakeException("Could not parse aci"); - } - } - - private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException { - if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) { - throw new NoiseHandshakeException("Invalid deviceId"); - } - return (byte) handshakePayload.getDeviceId(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java deleted file mode 100644 index e47120fb4..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelper.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.southernstorm.noise.protocol.HandshakeState; -import com.southernstorm.noise.protocol.Noise; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import java.security.NoSuchAlgorithmException; -import java.util.Optional; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; - -/** - * Helper for the responder of a 2-message handshake with a pre-shared responder static key - */ -class NoiseHandshakeHelper { - - private final static int AEAD_TAG_LENGTH = 16; - private final static int KEY_LENGTH = 32; - - private final HandshakePattern handshakePattern; - private final ECKeyPair serverStaticKeyPair; - private final HandshakeState handshakeState; - - NoiseHandshakeHelper(HandshakePattern handshakePattern, ECKeyPair serverStaticKeyPair) { - this.handshakePattern = handshakePattern; - this.serverStaticKeyPair = serverStaticKeyPair; - try { - this.handshakeState = new HandshakeState(handshakePattern.protocol(), HandshakeState.RESPONDER); - } catch (final NoSuchAlgorithmException e) { - throw new AssertionError("Unsupported Noise algorithm: " + handshakePattern.protocol(), e); - } - } - - /** - * Get the length of the initiator's keys - * - * @return length of the handshake message sent by the remote party (the initiator) not including the payload - */ - private int initiatorHandshakeMessageKeyLength() { - return switch (handshakePattern) { - // ephemeral key, static key (encrypted), AEAD tag for static key - case IK -> KEY_LENGTH + KEY_LENGTH + AEAD_TAG_LENGTH; - // ephemeral key only - case NK -> KEY_LENGTH; - }; - } - - HandshakeState getHandshakeState() { - return this.handshakeState; - } - - ByteBuf read(byte[] remoteHandshakeMessage) throws NoiseHandshakeException { - if (handshakeState.getAction() != HandshakeState.NO_ACTION) { - throw new NoiseHandshakeException("Cannot send more data before handshake is complete"); - } - - // Length for an empty payload - final int minMessageLength = initiatorHandshakeMessageKeyLength() + AEAD_TAG_LENGTH; - if (remoteHandshakeMessage.length < minMessageLength || remoteHandshakeMessage.length > Noise.MAX_PACKET_LEN) { - throw new NoiseHandshakeException("Unexpected ephemeral key message length"); - } - - final int payloadLength = remoteHandshakeMessage.length - initiatorHandshakeMessageKeyLength() - AEAD_TAG_LENGTH; - - // Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is - // making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key - // from the supplied private key (and will in fact overwrite any previously-set public key when setting a private - // key), so we just set the private key here. - handshakeState.getLocalKeyPair().setPrivateKey(serverStaticKeyPair.getPrivateKey().serialize(), 0); - handshakeState.start(); - - int payloadBytesRead; - - try { - payloadBytesRead = handshakeState.readMessage(remoteHandshakeMessage, 0, remoteHandshakeMessage.length, - remoteHandshakeMessage, 0); - } catch (final ShortBufferException e) { - // This should never happen since we're checking the length of the frame up front - throw new NoiseHandshakeException("Unexpected client payload"); - } catch (final BadPaddingException e) { - // We aren't using padding but may get this error if the AEAD tag does not match the encrypted client static key - // or payload - throw new NoiseHandshakeException("Invalid keys or payload"); - } - if (payloadBytesRead != payloadLength) { - throw new NoiseHandshakeException( - "Unexpected payload length, required " + payloadLength + " but got " + payloadBytesRead); - } - return Unpooled.wrappedBuffer(remoteHandshakeMessage, 0, payloadBytesRead); - } - - byte[] write(byte[] payload) { - if (handshakeState.getAction() != HandshakeState.WRITE_MESSAGE) { - throw new IllegalStateException("Cannot send data before handshake is complete"); - } - - // Currently only support handshake patterns where the server static key is known - // Send our ephemeral key and the response to the initiator with the encrypted payload - final byte[] response = new byte[KEY_LENGTH + payload.length + AEAD_TAG_LENGTH]; - try { - int written = handshakeState.writeMessage(response, 0, payload, 0, payload.length); - if (written != response.length) { - throw new IllegalStateException("Unexpected handshake response length"); - } - return response; - } catch (final ShortBufferException e) { - // This should never happen for messages of known length that we control - throw new IllegalStateException("Key material buffer was too short for message", e); - } - } - - Optional remotePublicKey() { - return Optional.ofNullable(handshakeState.getRemotePublicKey()).map(dhstate -> { - final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()]; - handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0); - return publicKeyFromClient; - }); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java deleted file mode 100644 index d5368f366..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeInit.java +++ /dev/null @@ -1,33 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.DefaultByteBufHolder; -import java.net.InetAddress; - -/** - * A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata - * and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic. - */ -public class NoiseHandshakeInit extends DefaultByteBufHolder { - - private final InetAddress remoteAddress; - private final HandshakePattern handshakePattern; - - public NoiseHandshakeInit( - final InetAddress remoteAddress, - final HandshakePattern handshakePattern, - final ByteBuf initiatorHandshakeMessage) { - super(initiatorHandshakeMessage); - this.remoteAddress = remoteAddress; - this.handshakePattern = handshakePattern; - } - - public InetAddress getRemoteAddress() { - return remoteAddress; - } - - public HandshakePattern getHandshakePattern() { - return handshakePattern; - } - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java deleted file mode 100644 index 45344248d..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseIdentityDeterminedEvent.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import java.net.InetAddress; -import java.util.Optional; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; - -/** - * An event that indicates that an identity of a noise handshake initiator has been determined. If the initiator is - * connecting anonymously, the identity is empty, otherwise it will be present and already authenticated. - * - * @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a - * type that performs authentication - * @param remoteAddress the remote address of the connecting client - * @param userAgent the client supplied userAgent - * @param acceptLanguage the client supplied acceptLanguage - */ -public record NoiseIdentityDeterminedEvent( - Optional authenticatedDevice, - InetAddress remoteAddress, - String userAgent, - String acceptLanguage) { -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java deleted file mode 100644 index 3ef4f84f1..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/OutboundCloseErrorMessage.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -/** - * An error written to the outbound pipeline that indicates the connection should be closed - */ -public record OutboundCloseErrorMessage(Code code, String message) { - public enum Code { - - /** - * The server decided to close the connection. This could be because the server is going away, or it could be - * because the credentials for the connected client have been updated. - */ - SERVER_CLOSED, - - /** - * There was a noise decryption error after the noise session was established - */ - NOISE_ERROR, - - /** - * There was an error establishing the noise handshake - */ - NOISE_HANDSHAKE_ERROR, - - INTERNAL_SERVER_ERROR - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java deleted file mode 100644 index 00afa516d..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyHandler.java +++ /dev/null @@ -1,24 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.channel.Channel; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; - -/** - * A proxy handler writes all data read from one channel to another peer channel. - */ -public class ProxyHandler extends ChannelInboundHandlerAdapter { - - private final Channel peerChannel; - - public ProxyHandler(final Channel peerChannel) { - this.peerChannel = peerChannel; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) { - peerChannel.writeAndFlush(message) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java deleted file mode 100644 index 3b8951658..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandler.java +++ /dev/null @@ -1,73 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import com.google.common.annotations.VisibleForTesting; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.haproxy.HAProxyMessageDecoder; - -/** - * A proxy protocol detection handler watches for HAProxy PROXY protocol messages at the beginning of a TCP connection. - * If a connection begins with a proxy message, this handler will add a {@link HAProxyMessageDecoder} to the pipeline. - * In all cases, once this handler has determined that a connection does or does not begin with a proxy protocol - * message, it will remove itself from the pipeline and pass any intercepted down the pipeline. - * - * @see The PROXY protocol - */ -public class ProxyProtocolDetectionHandler extends ChannelInboundHandlerAdapter { - - private CompositeByteBuf accumulator; - - @VisibleForTesting - static final int PROXY_MESSAGE_DETECTION_BYTES = 12; - - @Override - public void handlerAdded(final ChannelHandlerContext context) { - // We need at least 12 bytes to decide if a byte buffer contains a proxy protocol message. Assuming we only get - // non-empty buffers, that means we'll need at most 12 sub-buffers to have a complete message. In virtually every - // practical case, though, we'll be able to tell from the first packet. - accumulator = new CompositeByteBuf(context.alloc(), false, PROXY_MESSAGE_DETECTION_BYTES); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof ByteBuf byteBuf) { - accumulator.addComponent(true, byteBuf); - - switch (HAProxyMessageDecoder.detectProtocol(accumulator).state()) { - case NEEDS_MORE_DATA -> { - } - - case INVALID -> { - // We have enough information to determine that this connection is NOT starting with a proxy protocol message, - // and we can just pass the accumulated bytes through - context.fireChannelRead(accumulator); - - accumulator = null; - context.pipeline().remove(this); - } - - case DETECTED -> { - // We have enough information to know that we're dealing with a proxy protocol message; add appropriate - // handlers and pass the accumulated bytes through - context.pipeline().addAfter(context.name(), null, new HAProxyMessageDecoder()); - context.fireChannelRead(accumulator); - - accumulator = null; - context.pipeline().remove(this); - } - } - } else { - super.channelRead(context, message); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - if (accumulator != null) { - accumulator.release(); - accumulator = null; - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java deleted file mode 100644 index 0ed2f938b..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectDataFrameCodec.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.util.ReferenceCountUtil; -import org.whispersystems.textsecuregcm.grpc.net.NoiseException; - -/** - * In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the - * noise packet to the noise layer as a {@link ByteBuf} for decryption. - *

- * In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper - * so it can be wire serialized. This handler assumes the first outbound message received will correspond to the - * handshake response, and then the subsequent messages are all data frame payloads. - */ -public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler { - - @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { - if (msg instanceof NoiseDirectFrame frame) { - if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) { - ReferenceCountUtil.release(msg); - throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType()); - } - ctx.fireChannelRead(frame.content()); - } else { - ctx.fireChannelRead(msg); - } - } - - @Override - public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { - if (msg instanceof ByteBuf bb) { - ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise); - } else { - ctx.write(msg, promise); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java deleted file mode 100644 index 3545d5797..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrame.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.DefaultByteBufHolder; - -public class NoiseDirectFrame extends DefaultByteBufHolder { - - static final byte VERSION = 0x00; - - private final FrameType frameType; - - public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) { - super(data); - this.frameType = frameType; - } - - public FrameType frameType() { - return frameType; - } - - public byte versionedFrameTypeByte() { - final byte frameBits = frameType().getFrameBits(); - return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits); - } - - - public enum FrameType { - /** - * The payload is the initiator message for a Noise NK handshake. If established, the - * session will be unauthenticated. - */ - NK_HANDSHAKE((byte) 1), - /** - * The payload is the initiator message for a Noise IK handshake. If established, the - * session will be authenticated. - */ - IK_HANDSHAKE((byte) 2), - /** - * The payload is an encrypted noise packet. - */ - DATA((byte) 3), - /** - * A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being - * closed. - */ - CLOSE((byte) 4); - - private final byte frameType; - - FrameType(byte frameType) { - if (frameType != (0x0F & frameType)) { - throw new IllegalStateException("Frame type must fit in 4 bits"); - } - this.frameType = frameType; - } - - public byte getFrameBits() { - return frameType; - } - - public boolean isHandshake() { - return switch (this) { - case IK_HANDSHAKE, NK_HANDSHAKE -> true; - case DATA, CLOSE -> false; - }; - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java deleted file mode 100644 index 4d94d0400..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectFrameCodec.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import com.southernstorm.noise.protocol.Noise; -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.util.ReferenceCountUtil; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; - -/** - * Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes - * have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder} - */ -public class NoiseDirectFrameCodec extends ChannelDuplexHandler { - - @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { - if (msg instanceof ByteBuf byteBuf) { - try { - ctx.fireChannelRead(deserialize(byteBuf)); - } catch (Exception e) { - ReferenceCountUtil.release(byteBuf); - throw e; - } - } else { - ctx.fireChannelRead(msg); - } - } - - @Override - public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { - if (msg instanceof NoiseDirectFrame noiseDirectFrame) { - try { - // Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the - // network, nothing should have to make another copy of this. If later another layer is added, it may be more - // efficient to reuse the input buffer (typically not direct) by using a composite byte buffer - final ByteBuf serialized = serialize(ctx, noiseDirectFrame); - ctx.writeAndFlush(serialized, promise); - } finally { - ReferenceCountUtil.release(noiseDirectFrame); - } - } else { - ctx.write(msg, promise); - } - } - - private ByteBuf serialize( - final ChannelHandlerContext ctx, - final NoiseDirectFrame noiseDirectFrame) { - if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) { - throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes()); - } - - // 1 version/frametype byte, 2 length bytes, content - final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes()); - - byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte()); - byteBuf.writeShort(noiseDirectFrame.content().readableBytes()); - byteBuf.writeBytes(noiseDirectFrame.content()); - return byteBuf; - } - - private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception { - final byte versionAndFrameByte = byteBuf.readByte(); - final int version = (versionAndFrameByte & 0xF0) >> 4; - if (version != NoiseDirectFrame.VERSION) { - throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version); - } - final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F); - final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) { - case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE; - case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE; - case 3 -> NoiseDirectFrame.FrameType.DATA; - case 4 -> NoiseDirectFrame.FrameType.CLOSE; - default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits); - }; - - final int length = Short.toUnsignedInt(byteBuf.readShort()); - if (length != byteBuf.readableBytes()) { - throw new IllegalArgumentException( - "Payload length did not match remaining buffer, should have been guaranteed by a previous handler"); - } - return new NoiseDirectFrame(frameType, byteBuf.readSlice(length)); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java deleted file mode 100644 index bf72bf927..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectHandshakeSelector.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.util.ReferenceCountUtil; -import java.io.IOException; -import java.net.InetSocketAddress; -import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit; - -/** - * Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and - * forwards the handshake frame along as a {@link NoiseHandshakeInit} message - */ -public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter { - - @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { - if (msg instanceof NoiseDirectFrame frame) { - try { - if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) { - throw new IOException("Could not determine remote address"); - } - // We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler - // needs into a NoiseHandshakeInit message and forward that along - final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(), - switch (frame.frameType()) { - case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type"); - case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector"); - case IK_HANDSHAKE -> HandshakePattern.IK; - case NK_HANDSHAKE -> HandshakePattern.NK; - }, frame.content()); - - // Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames - // should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded - // for network serialization. Note that we need to install the Data frame handler before firing the read, - // because we may receive an outbound message from the noiseHandler - ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec()); - ctx.fireChannelRead(handshakeMessage); - } catch (Exception e) { - ReferenceCountUtil.release(msg); - throw e; - } - } else { - ctx.fireChannelRead(msg); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java deleted file mode 100644 index 1548bdb98..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectInboundCloseHandler.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import io.micrometer.core.instrument.Metrics; -import io.netty.buffer.ByteBufUtil; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.util.ReferenceCountUtil; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; - - -/** - * Watches for inbound close frames and closes the connection in response - */ -public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter { - private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "clientClose"); - @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { - if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) { - try { - final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason - .parseFrom(ByteBufUtil.getBytes(ndf.content())); - - Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment(); - } finally { - ReferenceCountUtil.release(msg); - ctx.close(); - } - } else { - ctx.fireChannelRead(msg); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java deleted file mode 100644 index b7b790e69..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectOutboundErrorHandler.java +++ /dev/null @@ -1,43 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import io.micrometer.core.instrument.Metrics; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufOutputStream; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; - -/** - * Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are - * written, the channel is closed - */ -class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter { - private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "serverClose"); - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof OutboundCloseErrorMessage err) { - final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) { - case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE; - case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR; - case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR; - case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR; - }; - Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "reason", code.name()).increment(); - - final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder() - .setCode(code) - .setMessage(err.message()) - .build(); - final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize()); - proto.writeTo(new ByteBufOutputStream(byteBuf)); - ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf)) - .addListener(ChannelFutureListener.CLOSE); - } else { - ctx.write(msg, promise); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java deleted file mode 100644 index ee3680826..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/NoiseDirectTunnelServer.java +++ /dev/null @@ -1,96 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import com.google.common.annotations.VisibleForTesting; -import com.southernstorm.noise.protocol.Noise; -import io.dropwizard.lifecycle.Managed; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.ServerSocketChannel; -import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; -import java.net.InetSocketAddress; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler; -import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler; -import org.whispersystems.textsecuregcm.grpc.net.FramingType; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler; -import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - -/** - * A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom - * binary framing protocol) and passes it through to a local gRPC server. - */ -public class NoiseDirectTunnelServer implements Managed { - - private final ServerBootstrap bootstrap; - private ServerSocketChannel channel; - - private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class); - - public NoiseDirectTunnelServer(final int port, - final NioEventLoopGroup eventLoopGroup, - final GrpcClientConnectionManager grpcClientConnectionManager, - final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair ecKeyPair, - final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress) { - - this.bootstrap = new ServerBootstrap() - .group(eventLoopGroup) - .channel(NioServerSocketChannel.class) - .localAddress(port) - .childHandler(new ChannelInitializer() { - @Override - protected void initChannel(SocketChannel socketChannel) { - socketChannel.pipeline() - .addLast(new ProxyProtocolDetectionHandler()) - .addLast(new HAProxyMessageHandler()) - // frame byte followed by a 2-byte length field - .addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2)) - // Parses NoiseDirectFrames from wire bytes and vice versa - .addLast(new NoiseDirectFrameCodec()) - // Terminate the connection if the client sends us a close frame - .addLast(new NoiseDirectInboundCloseHandler()) - // Turn generic OutboundCloseErrorMessages into noise direct error frames - .addLast(new NoiseDirectOutboundErrorHandler()) - // Forwards the first payload supplemented with handshake metadata, and then replaces itself with a - // NoiseDirectDataFrameCodec to handle subsequent data frames - .addLast(new NoiseDirectHandshakeSelector()) - // Performs the noise handshake and then replace itself with a NoiseHandler - .addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair)) - // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler - // once the Noise handshake has completed - .addLast(new EstablishLocalGrpcConnectionHandler( - grpcClientConnectionManager, - authenticatedGrpcServerAddress, anonymousGrpcServerAddress, - FramingType.NOISE_DIRECT)) - .addLast(new ErrorHandler()); - } - }); - } - - @VisibleForTesting - public InetSocketAddress getLocalAddress() { - return channel.localAddress(); - } - - @Override - public void start() throws InterruptedException { - channel = (ServerSocketChannel) bootstrap.bind().await().channel(); - } - - @Override - public void stop() throws InterruptedException { - if (channel != null) { - channel.close().await(); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java deleted file mode 100644 index 67a24b99f..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/ApplicationWebSocketCloseReason.java +++ /dev/null @@ -1,18 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; - -enum ApplicationWebSocketCloseReason { - NOISE_HANDSHAKE_ERROR(4001), - NOISE_ENCRYPTION_ERROR(4002); - - private final int statusCode; - - ApplicationWebSocketCloseReason(final int statusCode) { - this.statusCode = statusCode; - } - - public int getStatusCode() { - return statusCode; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java deleted file mode 100644 index eaff83bfa..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/NoiseWebSocketTunnelServer.java +++ /dev/null @@ -1,146 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import com.google.common.annotations.VisibleForTesting; -import com.southernstorm.noise.protocol.Noise; -import io.dropwizard.lifecycle.Managed; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.ServerSocketChannel; -import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioServerSocketChannel; -import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.HttpServerCodec; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import io.netty.handler.ssl.ClientAuth; -import io.netty.handler.ssl.OpenSsl; -import io.netty.handler.ssl.SslContext; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.handler.ssl.SslProtocols; -import io.netty.handler.ssl.SslProvider; -import java.net.InetSocketAddress; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; -import java.util.concurrent.Executor; -import javax.annotation.Nullable; -import javax.net.ssl.SSLException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.grpc.net.*; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - -/** - * A Noise-over-WebSocket tunnel server accepts traffic from the public internet (in the form of Noise packets framed by - * binary WebSocket frames) and passes it through to a local gRPC server. - */ -public class NoiseWebSocketTunnelServer implements Managed { - - private final ServerBootstrap bootstrap; - private ServerSocketChannel channel; - - static final String AUTHENTICATED_SERVICE_PATH = "/authenticated"; - static final String ANONYMOUS_SERVICE_PATH = "/anonymous"; - static final String HEALTH_CHECK_PATH = "/health-check"; - - private static final Logger log = LoggerFactory.getLogger(NoiseWebSocketTunnelServer.class); - - public NoiseWebSocketTunnelServer(final int websocketPort, - @Nullable final X509Certificate[] tlsCertificateChain, - @Nullable final PrivateKey tlsPrivateKey, - final NioEventLoopGroup eventLoopGroup, - final Executor delegatedTaskExecutor, - final GrpcClientConnectionManager grpcClientConnectionManager, - final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair ecKeyPair, - final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress, - final String recognizedProxySecret) throws SSLException { - - @Nullable final SslContext sslContext; - - if (tlsCertificateChain != null && tlsPrivateKey != null) { - final SslProvider sslProvider; - - if (OpenSsl.isAvailable()) { - log.info("Native OpenSSL provider is available; will use native provider"); - sslProvider = SslProvider.OPENSSL; - } else { - log.info("No native SSL provider available; will use JDK provider"); - sslProvider = SslProvider.JDK; - } - - sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain) - .clientAuth(ClientAuth.NONE) - // Some load balancers require TLS 1.2 for health checks - .protocols(SslProtocols.TLS_v1_3, SslProtocols.TLS_v1_2) - .sslProvider(sslProvider) - .build(); - } else { - log.warn("No TLS credentials provided; Noise-over-WebSocket tunnel will not use TLS. This configuration is not suitable for production environments."); - sslContext = null; - } - - this.bootstrap = new ServerBootstrap() - .group(eventLoopGroup) - .channel(NioServerSocketChannel.class) - .localAddress(websocketPort) - .childHandler(new ChannelInitializer() { - @Override - protected void initChannel(SocketChannel socketChannel) { - socketChannel.pipeline() - .addLast(new ProxyProtocolDetectionHandler()) - .addLast(new HAProxyMessageHandler()); - - if (sslContext != null) { - socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor)); - } - - socketChannel.pipeline() - .addLast(new HttpServerCodec()) - .addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN)) - // The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade - // request and passed it down the pipeline - .addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH)) - .addLast(new WebSocketServerProtocolHandler("/", true)) - // Metrics on inbound/outbound Close frames - .addLast(new WebSocketCloseMetricHandler()) - // Turn generic OutboundCloseErrorMessages into websocket close frames - .addLast(new WebSocketOutboundErrorHandler()) - .addLast(new RejectUnsupportedMessagesHandler()) - .addLast(new WebsocketPayloadCodec()) - // The WebSocket handshake complete listener will forward the first payload supplemented with - // data from the websocket handshake completion event, and then remove itself from the pipeline - .addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret)) - // The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a - // NoiseHandler - .addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair)) - // This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler - // once the Noise handshake has completed - .addLast(new EstablishLocalGrpcConnectionHandler( - grpcClientConnectionManager, - authenticatedGrpcServerAddress, anonymousGrpcServerAddress, - FramingType.WEBSOCKET)) - .addLast(new ErrorHandler()); - } - }); - } - - @VisibleForTesting - InetSocketAddress getLocalAddress() { - return channel.localAddress(); - } - - @Override - public void start() throws InterruptedException { - channel = (ServerSocketChannel) bootstrap.bind().await().channel(); - } - - @Override - public void stop() throws InterruptedException { - if (channel != null) { - channel.close().await(); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java deleted file mode 100644 index d313951b3..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandler.java +++ /dev/null @@ -1,35 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; -import io.netty.handler.codec.http.websocketx.WebSocketFrame; -import io.netty.util.ReferenceCountUtil; - -/** - * A "reject unsupported message" handler closes the channel if it receives messages it does not know how to process. - */ -public class RejectUnsupportedMessagesHandler extends ChannelInboundHandlerAdapter { - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof final WebSocketFrame webSocketFrame) { - if (webSocketFrame instanceof final TextWebSocketFrame textWebSocketFrame) { - try { - context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE)); - } finally { - textWebSocketFrame.release(); - } - } else { - // Allow all other types of WebSocket frames - context.fireChannelRead(webSocketFrame); - } - } else { - // Discard anything that's not a WebSocket frame - ReferenceCountUtil.release(message); - context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE)); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketCloseMetricHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketCloseMetricHandler.java deleted file mode 100644 index 52fa3de2d..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketCloseMetricHandler.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import io.micrometer.core.instrument.Metrics; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; - -public class WebSocketCloseMetricHandler extends ChannelDuplexHandler { - - private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "clientClose"); - private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "serverClose"); - - - @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { - if (msg instanceof CloseWebSocketFrame closeFrame) { - Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "closeCode", validatedCloseCode(closeFrame.statusCode())).increment(); - } - ctx.fireChannelRead(msg); - } - - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof CloseWebSocketFrame closeFrame) { - Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "closeCode", Integer.toString(closeFrame.statusCode())).increment(); - } - ctx.write(msg, promise); - } - - private static String validatedCloseCode(int closeCode) { - - if (closeCode >= 1000 && closeCode <= 1015) { - // RFC-6455 pre-defined status codes - return Integer.toString(closeCode); - } else if (closeCode >= 4000 && closeCode <= 4100) { - // Application status codes - return Integer.toString(closeCode); - } else { - return "unknown"; - } - } - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java deleted file mode 100644 index e70cdec8e..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandler.java +++ /dev/null @@ -1,81 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.util.ReferenceCountUtil; - -/** - * A WebSocket opening handshake handler serves as the "front door" for the WebSocket/Noise tunnel and gracefully - * rejects requests for anything other than a WebSocket connection to a known endpoint. - */ -class WebSocketOpeningHandshakeHandler extends ChannelInboundHandlerAdapter { - - private final String authenticatedPath; - private final String anonymousPath; - private final String healthCheckPath; - - WebSocketOpeningHandshakeHandler(final String authenticatedPath, - final String anonymousPath, - final String healthCheckPath) { - - this.authenticatedPath = authenticatedPath; - this.anonymousPath = anonymousPath; - this.healthCheckPath = healthCheckPath; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof FullHttpRequest request) { - boolean shouldReleaseRequest = true; - - try { - if (request.decoderResult().isSuccess()) { - if (HttpMethod.GET.equals(request.method())) { - if (authenticatedPath.equals(request.uri()) || anonymousPath.equals(request.uri())) { - if (request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) { - // Pass the request along to the websocket handshake handler and remove ourselves from the pipeline - shouldReleaseRequest = false; - - context.fireChannelRead(request); - context.pipeline().remove(this); - } else { - closeConnectionWithStatus(context, request, HttpResponseStatus.UPGRADE_REQUIRED); - } - } else if (healthCheckPath.equals(request.uri())) { - closeConnectionWithStatus(context, request, HttpResponseStatus.NO_CONTENT); - } else { - closeConnectionWithStatus(context, request, HttpResponseStatus.NOT_FOUND); - } - } else { - closeConnectionWithStatus(context, request, HttpResponseStatus.METHOD_NOT_ALLOWED); - } - } else { - closeConnectionWithStatus(context, request, HttpResponseStatus.BAD_REQUEST); - } - } finally { - if (shouldReleaseRequest) { - request.release(); - } - } - } else { - // Anything except HTTP requests should have been filtered out of the pipeline by now; treat this as an error - ReferenceCountUtil.release(message); - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); - } - } - - private static void closeConnectionWithStatus(final ChannelHandlerContext context, - final FullHttpRequest request, - final HttpResponseStatus status) { - - context.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), status)) - .addListener(ChannelFutureListener.CLOSE); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java deleted file mode 100644 index 421d575bb..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOutboundErrorHandler.java +++ /dev/null @@ -1,60 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; - -/** - * Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames - */ -class WebSocketOutboundErrorHandler extends ChannelDuplexHandler { - private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketOutboundErrorHandler.class, "serverClose"); - - private boolean websocketHandshakeComplete = false; - - private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class); - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception { - if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) { - setWebsocketHandshakeComplete(); - } - - context.fireUserEventTriggered(event); - } - - protected void setWebsocketHandshakeComplete() { - this.websocketHandshakeComplete = true; - } - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof OutboundCloseErrorMessage err) { - if (websocketHandshakeComplete) { - final int status = switch (err.code()) { - case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code(); - case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode(); - case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode(); - case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code(); - }; - ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - } else { - log.debug("Error {} occurred before websocket handshake complete", err); - // We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful - // way; just close the connection instead. - ctx.close(); - } - } else { - ctx.write(msg, promise); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java deleted file mode 100644 index 7093934bd..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandler.java +++ /dev/null @@ -1,152 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.net.InetAddresses; -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import io.netty.util.ReferenceCountUtil; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.util.Optional; -import javax.annotation.Nullable; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit; - -/** - * A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate - * Noise handshake handler for the requested path. - */ -class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { - - private final byte[] recognizedProxySecret; - - private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class); - - @VisibleForTesting - static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy"; - - @VisibleForTesting - static final String FORWARDED_FOR_HEADER = "X-Forwarded-For"; - - private InetAddress remoteAddress = null; - private HandshakePattern handshakePattern = null; - - WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) { - - // The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex- - // encoded value). We convert it into a byte array here for easier constant-time comparisons via - // MessageDigest.equals() later. - this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8); - } - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) { - if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { - final Optional maybePreferredRemoteAddress = - getPreferredRemoteAddress(context, handshakeCompleteEvent); - - if (maybePreferredRemoteAddress.isEmpty()) { - context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, - "Could not determine remote address")) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - - return; - } - - remoteAddress = maybePreferredRemoteAddress.get(); - handshakePattern = switch (handshakeCompleteEvent.requestUri()) { - case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK; - case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK; - // The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an - // internal error if something slipped through. - default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri()); - }; - } - - context.fireUserEventTriggered(event); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object msg) { - try { - if (!(msg instanceof ByteBuf frame)) { - throw new IllegalStateException("Unexpected msg type: " + msg.getClass()); - } - - if (handshakePattern == null || remoteAddress == null) { - throw new IllegalStateException("Received payload before websocket handshake complete"); - } - - final NoiseHandshakeInit handshakeMessage = - new NoiseHandshakeInit(remoteAddress, handshakePattern, frame); - - context.pipeline().remove(WebsocketHandshakeCompleteHandler.class); - context.fireChannelRead(handshakeMessage); - } catch (Exception e) { - ReferenceCountUtil.release(msg); - throw e; - } - } - - private Optional getPreferredRemoteAddress(final ChannelHandlerContext context, - final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) { - - final byte[] recognizedProxySecretFromHeader = - handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "") - .getBytes(StandardCharsets.UTF_8); - - final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader); - - if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) { - final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER); - - return getMostRecentProxy(forwardedFor).map(mostRecentProxy -> { - try { - return InetAddresses.forString(mostRecentProxy); - } catch (final IllegalArgumentException e) { - log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e); - return null; - } - }); - } else { - // Either we don't trust the forwarded-for header or it's not present - if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) { - return Optional.of(inetSocketAddress.getAddress()); - } else { - log.warn("Channel's remote address was not an InetSocketAddress"); - return Optional.empty(); - } - } - } - - /** - * Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header. - * - * @param forwardedFor the value of an X-Forwarded-For header - * @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or - * {@code forwardedFor} was null - * @see X-Forwarded-For - HTTP | - * MDN - */ - @VisibleForTesting - static Optional getMostRecentProxy(@Nullable final String forwardedFor) { - return Optional.ofNullable(forwardedFor) - .map(ff -> { - final int idx = forwardedFor.lastIndexOf(',') + 1; - return idx < forwardedFor.length() - ? forwardedFor.substring(idx).trim() - : null; - }) - .filter(StringUtils::isNotBlank); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java deleted file mode 100644 index 4ef5bd151..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketPayloadCodec.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; - -/** - * Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a - * {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise - * packet buffers in BinaryWebsocketFrames for writing through the websocket layer. - */ -public class WebsocketPayloadCodec extends ChannelDuplexHandler { - - @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) { - if (msg instanceof BinaryWebSocketFrame frame) { - ctx.fireChannelRead(frame.content()); - } else { - ctx.fireChannelRead(msg); - } - } - - @Override - public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { - if (msg instanceof ByteBuf bb) { - ctx.write(new BinaryWebSocketFrame(bb), promise); - } else { - ctx.write(msg, promise); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index e5386896e..d283c0c3c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -15,11 +15,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; import java.time.Clock; -import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; @@ -43,7 +39,6 @@ import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; @@ -267,9 +262,8 @@ record CommandDependencies( () -> dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion()); SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, storageServiceExecutor, retryExecutor, configuration.getSecureStorageServiceConfiguration()); - GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager(); DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, - grpcClientConnectionManager, disconnectionRequestListenerExecutor, retryExecutor); + disconnectionRequestListenerExecutor, retryExecutor); MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, messageDeletionExecutor, retryExecutor, Clock.systemUTC(), experimentEnrollmentManager); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, retryExecutor, asyncCdnS3Client, diff --git a/service/src/main/proto/NoiseDirect.proto b/service/src/main/proto/NoiseDirect.proto deleted file mode 100644 index a5c1ecf8c..000000000 --- a/service/src/main/proto/NoiseDirect.proto +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -syntax = "proto3"; - -option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect"; -option java_outer_classname = "NoiseDirectProtos"; - -message CloseReason { - enum Code { - UNSPECIFIED = 0; - // Indicates non-error termination - // Examples: - // - The client is finished with the connection - OK = 1; - - // There was an issue with the handshake. If sent after a handshake response, - // the response includes more information about the nature of the error - // Examples: - // - The client did not provide a handshake message - // - The client had incorrect authentication credentials. The handshake - // payload includes additional details - HANDSHAKE_ERROR = 2; - - // There was an encryption/decryption issue after the handshake - // Examples: - // - The client incorrectly encrypted a noise message and it had a bad - // AEAD tag - ENCRYPTION_ERROR = 3; - - // The server is temporarily unavailable, going away, or requires a - // connection reset - // Examples: - // - The server is shutting down - // - The client’s authentication credentials have been rotated - UNAVAILABLE = 4; - - // There was an an internal error - // Examples: - // - The server experienced a temporary database outage that prevented it - // from checking the client's credentials - INTERNAL_ERROR = 5; - } - - Code code = 1; - - // If present, includes details about the error. Implementations should never - // parse or otherwise implement conditional logic based on the contents of the - // error message string, it is for logging and debugging purposes only. - string message = 2; -} diff --git a/service/src/main/proto/NoiseTunnel.proto b/service/src/main/proto/NoiseTunnel.proto deleted file mode 100644 index 45695e20a..000000000 --- a/service/src/main/proto/NoiseTunnel.proto +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -syntax = "proto3"; - -option java_package = "org.whispersystems.textsecuregcm.grpc.net"; -option java_outer_classname = "NoiseTunnelProtos"; - -message HandshakeInit { - string user_agent = 1; - - // An Accept-Language as described in - // https://httpwg.org/specs/rfc9110.html#field.accept-language - string accept_language = 2; - - // A UUID serialized as 16 bytes (big end first). Must be unset (empty) for an - // unauthenticated handshake - bytes aci = 3; - - // The deviceId, 0 < deviceId < 128. Must be unset for an unauthenticated - // handshake - uint32 device_id = 4; - - // The first bytes of the application request byte stream, may contain less - // than a full request - bytes fast_open_request = 5; -} - -message HandshakeResponse { - enum Code { - UNSPECIFIED = 0; - - // The noise session may be used to send application layer requests - OK = 1; - - // The provided client static key did not match the registered public key - // for the provided aci/deviceId. - WRONG_PUBLIC_KEY = 2; - - // The client version is to old, it should be upgraded before retrying - DEPRECATED = 3; - } - - // The handshake outcome - Code code = 1; - - // Additional information about an error status, for debugging only - string error_details = 2; - - // An optional response to a fast_open_request provided in the HandshakeInit. - // Note that a response may not be present even if a fast_open_request was - // present. If so, the response will be returned in a later message. - bytes fast_open_response = 3; -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java index cfbf2cd94..7c3f78f7a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/DisconnectionRequestManagerTest.java @@ -20,8 +20,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.storage.Account; @@ -30,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Device; @Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class DisconnectionRequestManagerTest { - private GrpcClientConnectionManager grpcClientConnectionManager; private DisconnectionRequestManager disconnectionRequestManager; @RegisterExtension @@ -38,10 +35,8 @@ class DisconnectionRequestManagerTest { @BeforeEach void setUp() { - grpcClientConnectionManager = mock(GrpcClientConnectionManager.class); disconnectionRequestManager = new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(), - grpcClientConnectionManager, Runnable::run, mock(ScheduledExecutorService.class)); @@ -103,16 +98,8 @@ class DisconnectionRequestManagerTest { verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest(); verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest(); - verify(grpcClientConnectionManager, timeout(1_000)) - .closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId)); - - verify(grpcClientConnectionManager, timeout(1_000)) - .closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId)); disconnectionRequestManager.requestDisconnection(otherAccountIdentifier, List.of(otherDeviceId)); - - verify(grpcClientConnectionManager, timeout(1_000)) - .closeConnection(new AuthenticatedDevice(otherAccountIdentifier, otherDeviceId)); } @Test @@ -141,11 +128,5 @@ class DisconnectionRequestManagerTest { verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest(); verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest(); - - verify(grpcClientConnectionManager, timeout(1_000)) - .closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId)); - - verify(grpcClientConnectionManager, timeout(1_000)) - .closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java deleted file mode 100644 index 95bacd076..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/AbstractAuthenticationInterceptorTest.java +++ /dev/null @@ -1,77 +0,0 @@ -package org.whispersystems.textsecuregcm.auth.grpc; - -import static org.mockito.Mockito.mock; - -import io.grpc.ManagedChannel; -import io.grpc.Server; -import io.grpc.netty.NettyChannelBuilder; -import io.grpc.netty.NettyServerBuilder; -import io.netty.channel.DefaultEventLoopGroup; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.local.LocalServerChannel; -import java.io.IOException; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; -import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; -import org.signal.chat.rpc.RequestAttributesGrpc; -import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; - -abstract class AbstractAuthenticationInterceptorTest { - - private static DefaultEventLoopGroup eventLoopGroup; - - private GrpcClientConnectionManager grpcClientConnectionManager; - - private Server server; - private ManagedChannel managedChannel; - - @BeforeAll - static void setUpBeforeAll() { - eventLoopGroup = new DefaultEventLoopGroup(); - } - - @BeforeEach - void setUp() throws IOException { - final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server"); - - grpcClientConnectionManager = mock(GrpcClientConnectionManager.class); - - // `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make - // sure that we're using local channels and addresses - server = NettyServerBuilder.forAddress(serverAddress) - .channelType(LocalServerChannel.class) - .bossEventLoopGroup(eventLoopGroup) - .workerEventLoopGroup(eventLoopGroup) - .intercept(getInterceptor()) - .addService(new RequestAttributesServiceImpl()) - .build() - .start(); - - managedChannel = NettyChannelBuilder.forAddress(serverAddress) - .channelType(LocalChannel.class) - .eventLoopGroup(eventLoopGroup) - .usePlaintext() - .build(); - } - - @AfterEach - void tearDown() { - managedChannel.shutdown(); - server.shutdown(); - } - - protected abstract AbstractAuthenticationInterceptor getInterceptor(); - - protected GrpcClientConnectionManager getClientConnectionManager() { - return grpcClientConnectionManager; - } - - protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() { - return RequestAttributesGrpc.newBlockingStub(managedChannel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicAuthCallCredentials.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicAuthCallCredentials.java new file mode 100644 index 000000000..98e2c676a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicAuthCallCredentials.java @@ -0,0 +1,35 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.auth.grpc; + +import io.grpc.CallCredentials; +import io.grpc.Metadata; +import io.grpc.Status; +import java.util.concurrent.Executor; +import org.whispersystems.textsecuregcm.util.HeaderUtils; + +public class BasicAuthCallCredentials extends CallCredentials { + + private final String username; + private final String password; + + public BasicAuthCallCredentials(String username, String password) { + this.username = username; + this.password = password; + } + + @Override + public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor, + final MetadataApplier applier) { + try { + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER), + HeaderUtils.basicAuthHeader(username, password)); + applier.apply(headers); + } catch (Exception e) { + applier.fail(Status.UNAUTHENTICATED.withCause(e)); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java index a0d8c6688..189af0bd0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/ProhibitAuthenticationInterceptorTest.java @@ -1,44 +1,64 @@ package org.whispersystems.textsecuregcm.auth.grpc; +import io.grpc.ManagedChannel; +import io.grpc.Server; import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; -import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; -import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.storage.Device; +import org.signal.chat.rpc.EchoRequest; +import org.signal.chat.rpc.EchoServiceGrpc; +import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; -import java.util.Optional; -import java.util.UUID; +import java.util.concurrent.TimeUnit; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest { +class ProhibitAuthenticationInterceptorTest { + private Server server; + private ManagedChannel channel; - @Override - protected AbstractAuthenticationInterceptor getInterceptor() { - return new ProhibitAuthenticationInterceptor(getClientConnectionManager()); + @BeforeEach + void setUp() throws Exception { + server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest") + .directExecutor() + .intercept(new ProhibitAuthenticationInterceptor()) + .addService(new EchoServiceImpl()) + .build() + .start(); + + channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest") + .directExecutor() + .build(); + } + + @AfterEach + void tearDown() throws Exception { + channel.shutdownNow(); + server.shutdownNow(); + channel.awaitTermination(5, TimeUnit.SECONDS); + server.awaitTermination(5, TimeUnit.SECONDS); } @Test - void interceptCall() throws ChannelNotFoundException { - final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); + void hasAuth() { + final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc + .newBlockingStub(channel) + .withCallCredentials(new BasicAuthCallCredentials("test", "password")); - when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); + final StatusRuntimeException e = assertThrows(StatusRuntimeException.class, + () -> client.echo(EchoRequest.getDefaultInstance())); + assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED); + } - final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); - assertTrue(response.getAccountIdentifier().isEmpty()); - assertEquals(0, response.getDeviceId()); - - final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); - when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); - - GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice); - - when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class); - - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice); + @Test + void noAuth() { + final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel); + assertDoesNotThrow(() -> client.echo(EchoRequest.getDefaultInstance())); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java index 442b38cc8..90bb86f79 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/RequireAuthenticationInterceptorTest.java @@ -1,44 +1,102 @@ package org.whispersystems.textsecuregcm.auth.grpc; +import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import io.dropwizard.auth.basic.BasicCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import java.time.Instant; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.junit.Assert; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; -import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; -import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.storage.Device; +import org.signal.chat.rpc.GetRequestAttributesRequest; +import org.signal.chat.rpc.RequestAttributesGrpc; +import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; import org.whispersystems.textsecuregcm.util.UUIDUtil; -class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest { +class RequireAuthenticationInterceptorTest { + private Server server; + private ManagedChannel channel; + private AccountAuthenticator authenticator; - @Override - protected AbstractAuthenticationInterceptor getInterceptor() { - return new RequireAuthenticationInterceptor(getClientConnectionManager()); + @BeforeEach + void setUp() throws Exception { + authenticator = mock(AccountAuthenticator.class); + server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest") + .directExecutor() + .intercept(new RequireAuthenticationInterceptor(authenticator)) + .addService(new RequestAttributesServiceImpl()) + .build() + .start(); + + channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest") + .directExecutor() + .build(); + } + + @AfterEach + void tearDown() throws Exception { + channel.shutdownNow(); + server.shutdownNow(); + channel.awaitTermination(5, TimeUnit.SECONDS); + server.awaitTermination(5, TimeUnit.SECONDS); } @Test - void interceptCall() throws ChannelNotFoundException { - final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager(); + void hasAuth() { + final UUID aci = UUID.randomUUID(); + final byte deviceId = 2; + when(authenticator.authenticate(eq(new BasicCredentials("test", "password")))) + .thenReturn(Optional.of( + new org.whispersystems.textsecuregcm.auth.AuthenticatedDevice(aci, deviceId, Instant.now()))); - when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty()); + final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc + .newBlockingStub(channel) + .withCallCredentials(new BasicAuthCallCredentials("test", "password")); - GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice); + final GetAuthenticatedDeviceResponse authenticatedDevice = client.getAuthenticatedDevice( + GetAuthenticatedDeviceRequest.getDefaultInstance()); + assertEquals(authenticatedDevice.getDeviceId(), deviceId); + assertEquals(UUIDUtil.fromByteString(authenticatedDevice.getAccountIdentifier()), aci); + } - final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); - when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice)); + @Test + void badCredentials() { + when(authenticator.authenticate(any())).thenReturn(Optional.empty()); - final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice(); - assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier()); - assertEquals(authenticatedDevice.deviceId(), response.getDeviceId()); + final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc + .newBlockingStub(channel) + .withCallCredentials(new BasicAuthCallCredentials("test", "password")); - when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class); + final StatusRuntimeException e = assertThrows(StatusRuntimeException.class, + () -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance())); + Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED); + } - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice); + @Test + void missingCredentials() { + when(authenticator.authenticate(any())).thenReturn(Optional.empty()); + + final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc.newBlockingStub(channel); + + final StatusRuntimeException e = assertThrows(StatusRuntimeException.class, + () -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance())); + Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java deleted file mode 100644 index 35db29a77..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/ChannelShutdownInterceptorTest.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.grpc; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.Status; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; - -class ChannelShutdownInterceptorTest { - - private GrpcClientConnectionManager grpcClientConnectionManager; - private ChannelShutdownInterceptor channelShutdownInterceptor; - - private ServerCallHandler nextCallHandler; - - private static final Metadata HEADERS = new Metadata(); - - @BeforeEach - void setUp() { - grpcClientConnectionManager = mock(GrpcClientConnectionManager.class); - channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager); - - //noinspection unchecked - nextCallHandler = mock(ServerCallHandler.class); - - //noinspection unchecked - when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class)); - } - - @Test - void interceptCallComplete() { - @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class); - - when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true); - - final ServerCall.Listener serverCallListener = - channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler); - - serverCallListener.onComplete(); - - verify(grpcClientConnectionManager).handleServerCallStart(serverCall); - verify(grpcClientConnectionManager).handleServerCallComplete(serverCall); - verify(serverCall, never()).close(any(), any()); - } - - @Test - void interceptCallCancelled() { - @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class); - - when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true); - - final ServerCall.Listener serverCallListener = - channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler); - - serverCallListener.onCancel(); - - verify(grpcClientConnectionManager).handleServerCallStart(serverCall); - verify(grpcClientConnectionManager).handleServerCallComplete(serverCall); - verify(serverCall, never()).close(any(), any()); - } - - @Test - void interceptCallChannelClosing() { - @SuppressWarnings("unchecked") final ServerCall serverCall = mock(ServerCall.class); - - when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false); - - channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler); - - verify(grpcClientConnectionManager).handleServerCallStart(serverCall); - verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall); - verify(serverCall).close(eq(Status.UNAVAILABLE), any()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptorTest.java new file mode 100644 index 000000000..e30dd9549 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/RequestAttributesInterceptorTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.MetadataUtils; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.signal.chat.rpc.GetRequestAttributesRequest; +import org.signal.chat.rpc.GetRequestAttributesResponse; +import org.signal.chat.rpc.RequestAttributesGrpc; + +public class RequestAttributesInterceptorTest { + + private static String USER_AGENT = "Signal-Android/4.53.7 (Android 8.1; libsignal)"; + private Server server; + private AtomicBoolean removeUserAgent; + + @BeforeEach + void setUp() throws Exception { + removeUserAgent = new AtomicBoolean(false); + + server = NettyServerBuilder.forPort(0) + .directExecutor() + .intercept(new RequestAttributesInterceptor()) + // the grpc client always inserts a user-agent if we don't set one, so to test missing UAs we remove the header + // on the server-side + .intercept(new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(final ServerCall call, + final Metadata headers, final ServerCallHandler next) { + if (removeUserAgent.get()) { + headers.removeAll(Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER)); + } + return next.startCall(call, headers); + } + }) + .addService(new RequestAttributesServiceImpl()) + .build() + .start(); + } + + @AfterEach + void tearDown() throws Exception { + server.shutdownNow(); + server.awaitTermination(1, TimeUnit.SECONDS); + } + + private static List handleInvalidAcceptLanguage() { + return List.of( + Arguments.argumentSet("Null Accept-Language header", Optional.empty()), + Arguments.argumentSet("Empty Accept-Language header", Optional.of("")), + Arguments.argumentSet("Invalid Accept-Language header", Optional.of("This is not a valid language preference list"))); + } + + @ParameterizedTest + @MethodSource + void handleInvalidAcceptLanguage(Optional acceptLanguageHeader) throws Exception { + final Metadata metadata = new Metadata(); + acceptLanguageHeader.ifPresent(h -> metadata + .put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), h)); + final GetRequestAttributesResponse response = getRequestAttributes(metadata); + assertEquals(response.getAcceptableLanguagesCount(), 0); + } + + @Test + void handleMissingUserAgent() throws InterruptedException { + removeUserAgent.set(true); + final GetRequestAttributesResponse response = getRequestAttributes(new Metadata()); + assertEquals("", response.getUserAgent()); + } + + @Test + void allAttributes() throws InterruptedException { + final Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), "ja,en;q=0.4"); + metadata.put(Metadata.Key.of("x-forwarded-for", Metadata.ASCII_STRING_MARSHALLER), "127.0.0.3"); + final GetRequestAttributesResponse response = getRequestAttributes(metadata); + + assertTrue(response.getUserAgent().contains(USER_AGENT)); + assertEquals("127.0.0.3", response.getRemoteAddress()); + assertEquals(2, response.getAcceptableLanguagesCount()); + assertEquals("ja", response.getAcceptableLanguages(0)); + assertEquals("en;q=0.4", response.getAcceptableLanguages(1)); + } + + @Test + void useSocketAddrIfHeaderMissing() throws InterruptedException { + final GetRequestAttributesResponse response = getRequestAttributes(new Metadata()); + assertEquals("127.0.0.1", response.getRemoteAddress()); + } + + private GetRequestAttributesResponse getRequestAttributes(Metadata metadata) + throws InterruptedException { + final ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", server.getPort()) + .directExecutor() + .usePlaintext() + .userAgent(USER_AGENT) + .build(); + try { + final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc + .newBlockingStub(channel) + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + return client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance()); + } finally { + channel.shutdownNow(); + channel.awaitTermination(1, TimeUnit.SECONDS); + } + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java deleted file mode 100644 index 9cee9e7f1..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractLeakDetectionTest.java +++ /dev/null @@ -1,21 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import io.netty.util.ResourceLeakDetector; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; - -public abstract class AbstractLeakDetectionTest { - - private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel; - - @BeforeAll - static void setLeakDetectionLevel() { - originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel(); - ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); - } - - @AfterAll - static void restoreLeakDetectionLevel() { - ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java deleted file mode 100644 index 78d328cd8..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java +++ /dev/null @@ -1,325 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.mock; - -import com.southernstorm.noise.protocol.CipherStatePair; -import com.southernstorm.noise.protocol.Noise; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.util.ReferenceCountUtil; -import java.net.InetAddress; -import java.net.UnknownHostException; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.concurrent.ThreadLocalRandom; -import javax.annotation.Nullable; -import javax.crypto.AEADBadTagException; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; -import org.whispersystems.textsecuregcm.util.TestRandomUtil; - -abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { - - protected ECKeyPair serverKeyPair; - protected ClientPublicKeysManager clientPublicKeysManager; - - private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler; - - private EmbeddedChannel embeddedChannel; - - static final String USER_AGENT = "Test/User-Agent"; - static final String ACCEPT_LANGUAGE = "test-lang"; - static final InetAddress REMOTE_ADDRESS; - static { - try { - REMOTE_ADDRESS = InetAddress.getByAddress(new byte[]{0,1,2,3}); - } catch (UnknownHostException e) { - throw new RuntimeException(e); - } - } - - private static class PongHandler extends ChannelInboundHandlerAdapter { - - @Override - public void channelRead(final ChannelHandlerContext ctx, final Object msg) { - try { - if (msg instanceof ByteBuf bb) { - if (new String(ByteBufUtil.getBytes(bb)).equals("ping")) { - ctx.writeAndFlush(Unpooled.wrappedBuffer("pong".getBytes())) - .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); - } else { - throw new IllegalArgumentException("Unexpected message: " + new String(ByteBufUtil.getBytes(bb))); - } - } else { - throw new IllegalArgumentException("Unexpected message type: " + msg); - } - } finally { - ReferenceCountUtil.release(msg); - } - } - } - - private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter { - - @Nullable - private NoiseIdentityDeterminedEvent handshakeCompleteEvent = null; - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) { - if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) { - handshakeCompleteEvent = noiseIdentityDeterminedEvent; - context.pipeline().addAfter(context.name(), null, new PongHandler()); - context.pipeline().remove(NoiseHandshakeCompleteHandler.class); - } else { - context.fireUserEventTriggered(event); - } - } - - @Nullable - public NoiseIdentityDeterminedEvent getHandshakeCompleteEvent() { - return handshakeCompleteEvent; - } - } - - @BeforeEach - void setUp() { - serverKeyPair = ECKeyPair.generate(); - noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler(); - clientPublicKeysManager = mock(ClientPublicKeysManager.class); - embeddedChannel = new EmbeddedChannel( - new NoiseHandshakeHandler(clientPublicKeysManager, serverKeyPair), - noiseHandshakeCompleteHandler); - } - - @AfterEach - void tearDown() { - embeddedChannel.close(); - } - - protected EmbeddedChannel getEmbeddedChannel() { - return embeddedChannel; - } - - @Nullable - protected NoiseIdentityDeterminedEvent getNoiseHandshakeCompleteEvent() { - return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent(); - } - - protected abstract CipherStatePair doHandshake() throws Throwable; - - /** - * Read a message from the embedded channel and deserialize it with the provided client cipher state. If there are no - * waiting messages in the channel, return null. - */ - byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException { - final ByteBuf responseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); - if (responseFrame == null) { - return null; - } - final byte[] plaintext = new byte[responseFrame.readableBytes() - 16]; - final int read = clientCipherPair.getReceiver().decryptWithAd(null, - ByteBufUtil.getBytes(responseFrame), 0, - plaintext, 0, - responseFrame.readableBytes()); - assertEquals(read, plaintext.length); - return plaintext; - } - - - @Test - void handleInvalidInitialMessage() throws InterruptedException { - final byte[] contentBytes = new byte[17]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - final ByteBuf content = Unpooled.wrappedBuffer(contentBytes); - - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new NoiseHandshakeInit(REMOTE_ADDRESS, HandshakePattern.IK, content)).await(); - - assertFalse(writeFuture.isSuccess()); - assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause()); - assertEquals(0, content.refCnt()); - assertNull(getNoiseHandshakeCompleteEvent()); - } - - @Test - void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException { - final ByteBuf[] frames = new ByteBuf[7]; - - for (int i = 0; i < frames.length; i++) { - final byte[] contentBytes = new byte[17]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - frames[i] = Unpooled.wrappedBuffer(contentBytes); - - embeddedChannel.writeOneInbound(frames[i]).await(); - } - - for (final ByteBuf frame : frames) { - assertEquals(0, frame.refCnt()); - } - - assertNull(getNoiseHandshakeCompleteEvent()); - } - - @Test - void handleNonByteBufBinaryFrame() throws Throwable { - final byte[] contentBytes = new byte[17]; - ThreadLocalRandom.current().nextBytes(contentBytes); - - final BinaryWebSocketFrame message = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes)); - - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await(); - - assertFalse(writeFuture.isSuccess()); - assertInstanceOf(IllegalArgumentException.class, writeFuture.cause()); - assertEquals(0, message.refCnt()); - assertNull(getNoiseHandshakeCompleteEvent()); - - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - } - - @Test - void channelRead() throws Throwable { - final CipherStatePair clientCipherStatePair = doHandshake(); - final byte[] plaintext = "ping".getBytes(StandardCharsets.UTF_8); - final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()]; - clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length); - - final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(ciphertext); - assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess()); - assertEquals(0, ciphertextFrame.refCnt()); - - final byte[] response = readNextPlaintext(clientCipherStatePair); - assertArrayEquals("pong".getBytes(StandardCharsets.UTF_8), response); - } - - @Test - void channelReadBadCiphertext() throws Throwable { - doHandshake(); - final byte[] bogusCiphertext = new byte[32]; - io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext); - - final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(bogusCiphertext); - final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await(); - - assertEquals(0, ciphertextFrame.refCnt()); - assertFalse(readCiphertextFuture.isSuccess()); - assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause()); - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - } - - @Test - void channelReadUnexpectedMessageType() throws Throwable { - doHandshake(); - final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await(); - - assertFalse(readFuture.isSuccess()); - assertInstanceOf(IllegalArgumentException.class, readFuture.cause()); - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - } - - @Test - void write() throws Throwable { - final CipherStatePair clientCipherStatePair = doHandshake(); - final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8); - final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext); - - final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer); - assertTrue(writePlaintextFuture.await().isSuccess()); - assertEquals(0, plaintextBuffer.refCnt()); - - final ByteBuf ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); - assertNotNull(ciphertextFrame); - assertTrue(embeddedChannel.outboundMessages().isEmpty()); - - final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame); - ciphertextFrame.release(); - - final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()]; - clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length); - - assertArrayEquals(plaintext, decryptedPlaintext); - } - - @Test - void writeUnexpectedMessageType() throws Throwable { - doHandshake(); - final Object unexpectedMessaged = new Object(); - - final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged); - assertTrue(writeFuture.await().isSuccess()); - - assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll()); - assertTrue(embeddedChannel.outboundMessages().isEmpty()); - } - - @ParameterizedTest - @ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5}) - void writeHugeOutboundMessage(final int plaintextLength) throws Throwable { - final CipherStatePair clientCipherStatePair = doHandshake(); - final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength); - final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length)); - - final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer); - assertTrue(writePlaintextFuture.isSuccess()); - - final byte[] decryptedPlaintext = new byte[plaintextLength]; - int plaintextOffset = 0; - ByteBuf ciphertextFrame; - while ((ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll()) != null) { - assertTrue(ciphertextFrame.readableBytes() <= Noise.MAX_PACKET_LEN); - final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame); - ciphertextFrame.release(); - plaintextOffset += clientCipherStatePair.getReceiver() - .decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length); - } - assertArrayEquals(plaintext, decryptedPlaintext); - assertEquals(0, plaintextBuffer.refCnt()); - - } - - @Test - public void writeHugeInboundMessage() throws Throwable { - doHandshake(); - final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1); - embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big)); - assertThrows(NoiseException.class, embeddedChannel::checkException); - } - - @Test - public void channelAttributes() throws Throwable { - doHandshake(); - final NoiseIdentityDeterminedEvent event = getNoiseHandshakeCompleteEvent(); - assertEquals(REMOTE_ADDRESS, event.remoteAddress()); - assertEquals(USER_AGENT, event.userAgent()); - assertEquals(ACCEPT_LANGUAGE, event.acceptLanguage()); - } - - protected NoiseTunnelProtos.HandshakeInit.Builder baseHandshakeInit() { - return NoiseTunnelProtos.HandshakeInit.newBuilder() - .setUserAgent(USER_AGENT) - .setAcceptLanguage(ACCEPT_LANGUAGE); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java deleted file mode 100644 index 4d1eeda6b..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseTunnelServerIntegrationTest.java +++ /dev/null @@ -1,451 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import com.google.protobuf.ByteString; -import io.grpc.ManagedChannel; -import io.grpc.ServerBuilder; -import io.grpc.Status; -import io.grpc.netty.NettyChannelBuilder; -import io.grpc.stub.StreamObserver; -import io.netty.channel.DefaultEventLoopGroup; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.handler.codec.haproxy.HAProxyCommand; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; -import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.function.Supplier; -import org.apache.commons.lang3.RandomStringUtils; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.signal.chat.rpc.EchoRequest; -import org.signal.chat.rpc.EchoResponse; -import org.signal.chat.rpc.EchoServiceGrpc; -import org.signal.chat.rpc.GetAuthenticatedDeviceRequest; -import org.signal.chat.rpc.GetAuthenticatedDeviceResponse; -import org.signal.chat.rpc.GetRequestAttributesRequest; -import org.signal.chat.rpc.RequestAttributesGrpc; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; -import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; -import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor; -import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; -import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; -import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor; -import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl; -import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent; -import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.UUIDUtil; - -public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest { - - private static NioEventLoopGroup nioEventLoopGroup; - private static DefaultEventLoopGroup defaultEventLoopGroup; - private static ExecutorService delegatedTaskExecutor; - private static ExecutorService serverCallExecutor; - - private GrpcClientConnectionManager grpcClientConnectionManager; - private ClientPublicKeysManager clientPublicKeysManager; - - private ECKeyPair serverKeyPair; - private ECKeyPair clientKeyPair; - - private ManagedLocalGrpcServer authenticatedGrpcServer; - private ManagedLocalGrpcServer anonymousGrpcServer; - - private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID(); - private static final byte DEVICE_ID = Device.PRIMARY_ID; - - public static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16); - - @BeforeAll - static void setUpBeforeAll() { - nioEventLoopGroup = new NioEventLoopGroup(); - defaultEventLoopGroup = new DefaultEventLoopGroup(); - delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor(); - serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor(); - } - - @BeforeEach - void setUp() throws Exception { - - clientKeyPair = ECKeyPair.generate(); - serverKeyPair = ECKeyPair.generate(); - - grpcClientConnectionManager = new GrpcClientConnectionManager(); - - clientPublicKeysManager = mock(ClientPublicKeysManager.class); - when(clientPublicKeysManager.findPublicKey(any(), anyByte())) - .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - - when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - - final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated"); - final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous"); - - authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) { - @Override - protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder - .executor(serverCallExecutor) - .addService(new RequestAttributesServiceImpl()) - .addService(new EchoServiceImpl()) - .intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager)) - .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) - .intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager)); - } - }; - - authenticatedGrpcServer.start(); - - anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) { - @Override - protected void configureServer(final ServerBuilder serverBuilder) { - serverBuilder - .executor(serverCallExecutor) - .addService(new RequestAttributesServiceImpl()) - .intercept(new RequestAttributesInterceptor(grpcClientConnectionManager)) - .intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager)); - } - }; - - anonymousGrpcServer.start(); - this.start( - nioEventLoopGroup, - delegatedTaskExecutor, - grpcClientConnectionManager, - clientPublicKeysManager, - serverKeyPair, - authenticatedGrpcServerAddress, anonymousGrpcServerAddress, - RECOGNIZED_PROXY_SECRET); - } - - - protected abstract void start( - final NioEventLoopGroup eventLoopGroup, - final Executor delegatedTaskExecutor, - final GrpcClientConnectionManager grpcClientConnectionManager, - final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair serverKeyPair, - final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress, - final String recognizedProxySecret) throws Exception; - protected abstract void stop() throws Exception; - protected abstract NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey); - - public void assertClosedWith(final NoiseTunnelClient client, final CloseFrameEvent.CloseReason reason) - throws ExecutionException, InterruptedException, TimeoutException { - final CloseFrameEvent result = client.closeFrameFuture().get(1, TimeUnit.SECONDS); - assertEquals(reason, result.closeReason()); - } - - @AfterEach - void tearDown() throws Exception { - authenticatedGrpcServer.stop(); - anonymousGrpcServer.stop(); - this.stop(); - } - - @AfterAll - static void tearDownAfterAll() throws InterruptedException { - nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); - defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await(); - - delegatedTaskExecutor.shutdown(); - //noinspection ResultOfMethodCallIgnored - delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS); - - serverCallExecutor.shutdown(); - //noinspection ResultOfMethodCallIgnored - serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException { - try (final NoiseTunnelClient client = authenticated() - .setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage)) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); - assertEquals(DEVICE_ID, response.getDeviceId()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void connectAuthenticatedBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException { - - // Try to verify the server's public key with something other than the key with which it was signed - try (final NoiseTunnelClient client = authenticated() - .setServerPublicKey(ECKeyPair.generate().getPublicKey()) - .build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); - } - } - - @Test - void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException, ExecutionException, TimeoutException { - - when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey()))); - - try (final NoiseTunnelClient client = authenticated().build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - assertEquals( - NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY, - client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode()); - assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); - } - } - - @Test - void connectAuthenticatedUnrecognizedDevice() throws InterruptedException, ExecutionException, TimeoutException { - when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID)) - .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - - try (final NoiseTunnelClient client = authenticated().build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - assertEquals( - NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY, - client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode()); - assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); - } - - } - - @Test - void clientNormalClosure() throws InterruptedException { - final NoiseTunnelClient client = anonymous().build(); - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertTrue(response.getAccountIdentifier().isEmpty()); - assertEquals(0, response.getDeviceId()); - client.close(); - - // When we gracefully close the tunnel client, we should send an OK close frame - final CloseFrameEvent closeFrame = client.closeFrameFuture().join(); - assertEquals(CloseFrameEvent.CloseInitiator.CLIENT, closeFrame.closeInitiator()); - assertEquals(CloseFrameEvent.CloseReason.OK, closeFrame.closeReason()); - } finally { - channel.shutdown(); - } - } - - @Test - void connectAnonymous() throws InterruptedException { - try (final NoiseTunnelClient client = anonymous().build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertTrue(response.getAccountIdentifier().isEmpty()); - assertEquals(0, response.getDeviceId()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void connectAnonymousBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException { - - // Try to verify the server's public key with something other than the key with which it was signed - try (final NoiseTunnelClient client = anonymous() - .setServerPublicKey(ECKeyPair.generate().getPublicKey()) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); - } - - } - - protected ManagedChannel buildManagedChannel(final LocalAddress localAddress) { - return NettyChannelBuilder.forAddress(localAddress) - .channelType(LocalChannel.class) - .eventLoopGroup(defaultEventLoopGroup) - .usePlaintext() - .build(); - } - - - @Test - void closeForReauthentication() throws InterruptedException, ExecutionException, TimeoutException { - - try (final NoiseTunnelClient client = authenticated().build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); - assertEquals(DEVICE_ID, response.getDeviceId()); - - grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); - final CloseFrameEvent closeEvent = client.closeFrameFuture().get(2, TimeUnit.SECONDS); - assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeEvent.closeReason()); - assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeEvent.closeInitiator()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void waitForCallCompletion() throws InterruptedException, ExecutionException, TimeoutException { - try (final NoiseTunnelClient client = authenticated().build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build()); - - assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier()); - assertEquals(DEVICE_ID, response.getDeviceId()); - - final CountDownLatch responseCountDownLatch = new CountDownLatch(1); - - // Start an open-ended server call and leave it in a non-complete state - final StreamObserver echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream( - new StreamObserver<>() { - @Override - public void onNext(final EchoResponse echoResponse) { - responseCountDownLatch.countDown(); - } - - @Override - public void onError(final Throwable throwable) { - } - - @Override - public void onCompleted() { - } - }); - - // Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before - // the request even starts. Make sure we've done at least one request/response pair to ensure that the call has - // truly started before requesting connection closure. - echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()); - assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS)); - - grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID)); - try { - client.closeFrameFuture().get(100, TimeUnit.MILLISECONDS); - fail("Channel should not close until active requests have finished"); - } catch (TimeoutException e) { - } - - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel) - .echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build())); - - // Complete the open-ended server call - echoRequestStreamObserver.onCompleted(); - - final CloseFrameEvent closeFrameEvent = client.closeFrameFuture().get(1, TimeUnit.SECONDS); - assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeFrameEvent.closeInitiator()); - assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeFrameEvent.closeReason()); - } finally { - channel.shutdown(); - } - } - } - - protected NoiseTunnelClient.Builder anonymous() { - return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey()); - } - - protected NoiseTunnelClient.Builder authenticated() { - return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey()) - .setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID); - } - - private static Supplier proxyMessageSupplier(boolean includeProxyMesage) { - return includeProxyMesage - ? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, - "10.0.0.1", "10.0.0.2", 12345, 443) - : null; - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java deleted file mode 100644 index 654ac5d72..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/GrpcClientConnectionManagerTest.java +++ /dev/null @@ -1,227 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.google.common.net.InetAddresses; -import io.netty.bootstrap.Bootstrap; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.DefaultEventLoopGroup; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.local.LocalServerChannel; -import java.net.InetAddress; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.Optional; -import java.util.UUID; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException; -import org.whispersystems.textsecuregcm.grpc.RequestAttributes; -import org.whispersystems.textsecuregcm.storage.Device; - -class GrpcClientConnectionManagerTest { - - private static EventLoopGroup eventLoopGroup; - - private LocalChannel localChannel; - private LocalChannel remoteChannel; - - private LocalServerChannel localServerChannel; - - private GrpcClientConnectionManager grpcClientConnectionManager; - - @BeforeAll - static void setUpBeforeAll() { - eventLoopGroup = new DefaultEventLoopGroup(); - } - - @BeforeEach - void setUp() throws InterruptedException { - eventLoopGroup = new DefaultEventLoopGroup(); - - grpcClientConnectionManager = new GrpcClientConnectionManager(); - - // We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial - // local server to which we can open trivial local connections - localServerChannel = (LocalServerChannel) new ServerBootstrap() - .group(eventLoopGroup) - .channel(LocalServerChannel.class) - .childHandler(new ChannelInitializer<>() { - @Override - protected void initChannel(final Channel channel) { - } - }) - .bind(new LocalAddress("test-server")) - .await() - .channel(); - - final Bootstrap clientBootstrap = new Bootstrap() - .group(eventLoopGroup) - .channel(LocalChannel.class) - .handler(new ChannelInitializer<>() { - @Override - protected void initChannel(final Channel ch) { - } - }); - - localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel(); - remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel(); - } - - @AfterEach - void tearDown() throws InterruptedException { - localChannel.close().await(); - remoteChannel.close().await(); - localServerChannel.close().await(); - } - - @AfterAll - static void tearDownAfterAll() throws InterruptedException { - eventLoopGroup.shutdownGracefully().await(); - } - - @ParameterizedTest - @MethodSource - void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional maybeAuthenticatedDevice) { - grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice); - - assertEquals(maybeAuthenticatedDevice, - grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel)); - } - - private static List> getAuthenticatedDevice() { - return List.of( - Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)), - Optional.empty() - ); - } - - @Test - void getRequestAttributes() { - grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); - - assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel)); - - final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null); - remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes); - - assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel)); - } - - @Test - void closeConnection() throws InterruptedException, ChannelNotFoundException { - final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); - - grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); - - assertTrue(remoteChannel.isOpen()); - - assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - assertEquals(List.of(remoteChannel), - grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); - - remoteChannel.close().await(); - - assertThrows(ChannelNotFoundException.class, - () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - - assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); - } - - @ParameterizedTest - @MethodSource - void handleHandshakeInitiatedRequestAttributes(final InetAddress preferredRemoteAddress, - final String userAgentHeader, - final String acceptLanguageHeader, - final RequestAttributes expectedRequestAttributes) { - - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - - GrpcClientConnectionManager.handleHandshakeInitiated(embeddedChannel, - preferredRemoteAddress, - userAgentHeader, - acceptLanguageHeader); - - assertEquals(expectedRequestAttributes, - embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get()); - } - - private static List handleHandshakeInitiatedRequestAttributes() { - final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1"); - - return List.of( - Arguments.argumentSet("Null User-Agent and Accept-Language headers", - preferredRemoteAddress, null, null, - new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())), - - Arguments.argumentSet("Recognized User-Agent and null Accept-Language header", - preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null, - new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())), - - Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header", - preferredRemoteAddress, "Not a valid user-agent string", null, - new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())), - - Arguments.argumentSet("Null User-Agent and parsable Accept-Language header", - preferredRemoteAddress, null, "ja,en;q=0.4", - new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))), - - Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header", - preferredRemoteAddress, null, "This is not a valid language preference list", - new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())) - ); - } - - @Test - void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException { - final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID); - - assertThrows(ChannelNotFoundException.class, - () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - - assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); - - grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice)); - - assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); - - remoteChannel.close().await(); - - assertThrows(ChannelNotFoundException.class, - () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - - assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice)); - } - - @Test - void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException { - assertThrows(ChannelNotFoundException.class, - () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - - grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty()); - - assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - - remoteChannel.close().await(); - - assertThrows(ChannelNotFoundException.class, - () -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress())); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java deleted file mode 100644 index de26f3c82..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/HAProxyMessageHandlerTest.java +++ /dev/null @@ -1,62 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFuture; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.haproxy.HAProxyCommand; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; -import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -class HAProxyMessageHandlerTest { - - private EmbeddedChannel embeddedChannel; - - @BeforeEach - void setUp() { - embeddedChannel = new EmbeddedChannel(new HAProxyMessageHandler()); - } - - @Test - void handleHAProxyMessage() throws InterruptedException { - final HAProxyMessage haProxyMessage = new HAProxyMessage( - HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, - "10.0.0.1", "10.0.0.2", 12345, 443); - - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(haProxyMessage); - embeddedChannel.flushInbound(); - - writeFuture.await(); - - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - assertEquals(0, haProxyMessage.refCnt()); - - assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); - } - - @Test - void handleNonHAProxyMessage() throws InterruptedException { - final byte[] bytes = new byte[32]; - ThreadLocalRandom.current().nextBytes(bytes); - - final ByteBuf message = Unpooled.wrappedBuffer(bytes); - - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message); - embeddedChannel.flushInbound(); - - writeFuture.await(); - - assertEquals(1, embeddedChannel.inboundMessages().size()); - assertEquals(message, embeddedChannel.inboundMessages().poll()); - assertEquals(1, message.refCnt()); - - assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java deleted file mode 100644 index 0a775ccbf..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAnonymousHandlerTest.java +++ /dev/null @@ -1,116 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.google.protobuf.ByteString; -import com.southernstorm.noise.protocol.CipherStatePair; -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.embedded.EmbeddedChannel; -import java.util.Optional; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.junit.jupiter.api.Test; - -class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest { - - @Override - protected CipherStatePair doHandshake() throws Exception { - return doHandshake(baseHandshakeInit().build().toByteArray()); - } - - private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception { - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - - final HandshakeState clientHandshakeState = - new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR); - - clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0); - clientHandshakeState.start(); - - // Send initiator handshake message - - // 32 byte key, request payload, 16 byte AEAD tag - final int initiateHandshakeMessageLength = 32 + requestPayload.length + 16; - final byte[] initiateHandshakeMessage = new byte[initiateHandshakeMessageLength]; - assertEquals( - initiateHandshakeMessageLength, - clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length)); - final NoiseHandshakeInit message = new NoiseHandshakeInit( - REMOTE_ADDRESS, - HandshakePattern.NK, - Unpooled.wrappedBuffer(initiateHandshakeMessage)); - assertTrue(embeddedChannel.writeOneInbound(message).await().isSuccess()); - assertEquals(0, message.refCnt()); - - embeddedChannel.runPendingTasks(); - - // Read responder handshake message - assertFalse(embeddedChannel.outboundMessages().isEmpty()); - final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); - assertNotNull(responderHandshakeFrame); - final byte[] responderHandshakeBytes = ByteBufUtil.getBytes(responderHandshakeFrame); - - final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponse = NoiseTunnelProtos.HandshakeResponse.newBuilder() - .setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK) - .build(); - - // ephemeral key, payload, AEAD tag - assertEquals(32 + expectedHandshakeResponse.getSerializedSize() + 16, responderHandshakeBytes.length); - - final byte[] handshakeResponsePlaintext = new byte[expectedHandshakeResponse.getSerializedSize()]; - assertEquals(expectedHandshakeResponse.getSerializedSize(), - clientHandshakeState.readMessage( - responderHandshakeBytes, 0, responderHandshakeBytes.length, - handshakeResponsePlaintext, 0)); - - assertEquals(expectedHandshakeResponse, NoiseTunnelProtos.HandshakeResponse.parseFrom(handshakeResponsePlaintext)); - - final byte[] serverPublicKey = new byte[32]; - clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); - assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes()); - - return clientHandshakeState.split(); - } - - @Test - void handleCompleteHandshakeWithRequest() throws Exception { - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - final byte[] handshakePlaintext = baseHandshakeInit() - .setFastOpenRequest(ByteString.copyFromUtf8("ping")).build() - .toByteArray(); - - final CipherStatePair cipherStatePair = doHandshake(handshakePlaintext); - final byte[] response = readNextPlaintext(cipherStatePair); - assertArrayEquals(response, "pong".getBytes()); - - assertEquals( - new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), - getNoiseHandshakeCompleteEvent()); - } - - @Test - void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException { - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake()); - assertNull(readNextPlaintext(cipherStatePair)); - - assertEquals( - new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), - getNoiseHandshakeCompleteEvent()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java deleted file mode 100644 index 225800963..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseAuthenticatedHandlerTest.java +++ /dev/null @@ -1,338 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -import com.google.protobuf.ByteString; -import com.southernstorm.noise.protocol.CipherStatePair; -import com.southernstorm.noise.protocol.HandshakeState; -import com.southernstorm.noise.protocol.Noise; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFuture; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.util.internal.EmptyArrays; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.security.NoSuchAlgorithmException; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ThreadLocalRandom; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.junit.jupiter.api.Test; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.TestRandomUtil; -import org.whispersystems.textsecuregcm.util.UUIDUtil; - -class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest { - - private final ECKeyPair clientKeyPair = ECKeyPair.generate(); - - @Override - protected CipherStatePair doHandshake() throws Throwable { - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = randomDeviceId(); - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - return doHandshake(identityPayload(accountIdentifier, deviceId)); - } - - @Test - void handleCompleteHandshakeNoInitialRequest() throws Throwable { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = randomDeviceId(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - - assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId)))); - - assertEquals( - new NoiseIdentityDeterminedEvent( - Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)), - REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), - getNoiseHandshakeCompleteEvent()); - } - - @Test - void handleCompleteHandshakeWithInitialRequest() throws Throwable { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = randomDeviceId(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey()))); - - final byte[] handshakeInit = identifiedHandshakeInit(accountIdentifier, deviceId) - .setFastOpenRequest(ByteString.copyFromUtf8("ping")) - .build() - .toByteArray(); - - final byte[] response = readNextPlaintext(doHandshake(handshakeInit)); - assertEquals(4, response.length); - assertEquals("pong", new String(response)); - - assertEquals( - new NoiseIdentityDeterminedEvent( - Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)), - REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE), - getNoiseHandshakeCompleteEvent()); - } - - @Test - void handleCompleteHandshakeMissingIdentityInformation() { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES)); - - verifyNoInteractions(clientPublicKeysManager); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - } - - @Test - void handleCompleteHandshakeMalformedIdentityInformation() { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - // no deviceId byte - byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID()); - assertThrows(NoiseHandshakeException.class, () -> doHandshake(malformedIdentityPayload)); - - verifyNoInteractions(clientPublicKeysManager); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - } - - @Test - void handleCompleteHandshakeUnrecognizedDevice() throws Throwable { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = randomDeviceId(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - - doHandshake( - identityPayload(accountIdentifier, deviceId), - NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - - assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class), - "Noise stream handler should not be added to pipeline after failed handshake"); - } - - @Test - void handleCompleteHandshakePublicKeyMismatch() throws Throwable { - - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = randomDeviceId(); - - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey()))); - - doHandshake( - identityPayload(accountIdentifier, deviceId), - NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY); - - assertNull(getNoiseHandshakeCompleteEvent()); - - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class), - "Handshake handler should not remove self from pipeline after failed handshake"); - } - - @Test - void handleInvalidExtraWrites() - throws NoSuchAlgorithmException, ShortBufferException, InterruptedException { - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class)); - - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = randomDeviceId(); - - final HandshakeState clientHandshakeState = clientHandshakeState(); - - final CompletableFuture> findPublicKeyFuture = new CompletableFuture<>(); - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture); - - final NoiseHandshakeInit handshakeInit = new NoiseHandshakeInit( - REMOTE_ADDRESS, - HandshakePattern.IK, - Unpooled.wrappedBuffer( - initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId)))); - assertTrue(embeddedChannel.writeOneInbound(handshakeInit).await().isSuccess()); - - // While waiting for the public key, send another message - final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await(); - assertInstanceOf(IllegalArgumentException.class, f.exceptionNow()); - - findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey())); - embeddedChannel.runPendingTasks(); - } - - @Test - public void handleOversizeHandshakeMessage() { - final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1); - ByteBuffer.wrap(big) - .put(UUIDUtil.toBytes(UUID.randomUUID())) - .put((byte) 0x01); - assertThrows(NoiseHandshakeException.class, () -> doHandshake(big)); - } - - @Test - public void handleKeyLookupError() { - final UUID accountIdentifier = UUID.randomUUID(); - final byte deviceId = randomDeviceId(); - when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)) - .thenReturn(CompletableFuture.failedFuture(new IOException())); - assertThrows(IOException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId))); - } - - private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException { - final HandshakeState clientHandshakeState = - new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR); - - clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0); - clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0); - clientHandshakeState.start(); - return clientHandshakeState; - } - - private byte[] initiatorHandshakeMessage(final HandshakeState clientHandshakeState, final byte[] payload) - throws ShortBufferException { - // Ephemeral key, encrypted static key, AEAD tag, encrypted payload, AEAD tag - final byte[] initiatorMessageBytes = new byte[32 + 32 + 16 + payload.length + 16]; - int written = clientHandshakeState.writeMessage(initiatorMessageBytes, 0, payload, 0, payload.length); - assertEquals(written, initiatorMessageBytes.length); - return initiatorMessageBytes; - } - - private byte[] readHandshakeResponse(final HandshakeState clientHandshakeState, final byte[] message) - throws ShortBufferException, BadPaddingException { - - // 32 byte ephemeral server key, 16 byte AEAD tag for encrypted payload - final int expectedResponsePayloadLength = message.length - 32 - 16; - final byte[] responsePayload = new byte[expectedResponsePayloadLength]; - final int responsePayloadLength = clientHandshakeState.readMessage(message, 0, message.length, responsePayload, 0); - assertEquals(expectedResponsePayloadLength, responsePayloadLength); - return responsePayload; - } - - private CipherStatePair doHandshake(final byte[] payload) throws Throwable { - return doHandshake(payload, NoiseTunnelProtos.HandshakeResponse.Code.OK); - } - - private CipherStatePair doHandshake(final byte[] payload, final NoiseTunnelProtos.HandshakeResponse.Code expectedStatus) throws Throwable { - final EmbeddedChannel embeddedChannel = getEmbeddedChannel(); - - final HandshakeState clientHandshakeState = clientHandshakeState(); - final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload); - - final NoiseHandshakeInit initMessage = new NoiseHandshakeInit( - REMOTE_ADDRESS, - HandshakePattern.IK, - Unpooled.wrappedBuffer(initiatorMessage)); - final ChannelFuture await = embeddedChannel.writeOneInbound(initMessage).await(); - assertEquals(0, initMessage.refCnt()); - if (!await.isSuccess() && expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) { - throw await.cause(); - } - - // The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the - // result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same - // thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback - // and issue a "handshake complete" event. - embeddedChannel.runPendingTasks(); - - // rethrow if running the task caused an error, and the caller isn't expecting an error - if (expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) { - embeddedChannel.checkException(); - } - - assertFalse(embeddedChannel.outboundMessages().isEmpty()); - - final ByteBuf handshakeResponseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll(); - assertNotNull(handshakeResponseFrame); - final byte[] handshakeResponseCiphertextBytes = ByteBufUtil.getBytes(handshakeResponseFrame); - - final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponsePlaintext = NoiseTunnelProtos.HandshakeResponse.newBuilder() - .setCode(expectedStatus) - .build(); - - final byte[] actualHandshakeResponsePlaintext = - readHandshakeResponse(clientHandshakeState, handshakeResponseCiphertextBytes); - - assertEquals( - expectedHandshakeResponsePlaintext, - NoiseTunnelProtos.HandshakeResponse.parseFrom(actualHandshakeResponsePlaintext)); - - final byte[] serverPublicKey = new byte[32]; - clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0); - assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes()); - - return clientHandshakeState.split(); - } - - private NoiseTunnelProtos.HandshakeInit.Builder identifiedHandshakeInit(final UUID accountIdentifier, final byte deviceId) { - return baseHandshakeInit() - .setAci(UUIDUtil.toByteString(accountIdentifier)) - .setDeviceId(deviceId); - } - - private byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) { - return identifiedHandshakeInit(accountIdentifier, deviceId) - .build() - .toByteArray(); - } - - private static byte randomDeviceId() { - return (byte) ThreadLocalRandom.current().nextInt(1, Device.MAXIMUM_DEVICE_ID + 1); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java deleted file mode 100644 index 13d26813c..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandshakeHelperTest.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNoException; - -import com.southernstorm.noise.protocol.HandshakeState; -import io.netty.buffer.ByteBuf; -import java.nio.charset.StandardCharsets; -import javax.crypto.ShortBufferException; -import io.netty.buffer.ByteBufUtil; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.EnumSource; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientHandshakeHelper; - - -public class NoiseHandshakeHelperTest { - - @ParameterizedTest - @EnumSource(HandshakePattern.class) - void testWithPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException { - doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), "pong".getBytes(StandardCharsets.UTF_8)); - } - - @ParameterizedTest - @EnumSource(HandshakePattern.class) - void testWithRequestPayload(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException { - doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), new byte[0]); - } - - @ParameterizedTest - @EnumSource(HandshakePattern.class) - void testWithoutPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException { - doHandshake(pattern, new byte[0], new byte[0]); - } - - void doHandshake(final HandshakePattern pattern, final byte[] requestPayload, final byte[] responsePayload) throws ShortBufferException, NoiseHandshakeException { - final ECKeyPair serverKeyPair = ECKeyPair.generate(); - final ECKeyPair clientKeyPair = ECKeyPair.generate(); - - NoiseHandshakeHelper serverHelper = new NoiseHandshakeHelper(pattern, serverKeyPair); - NoiseClientHandshakeHelper clientHelper = switch (pattern) { - case IK -> NoiseClientHandshakeHelper.IK(serverKeyPair.getPublicKey(), clientKeyPair); - case NK -> NoiseClientHandshakeHelper.NK(serverKeyPair.getPublicKey()); - }; - - final byte[] initiate = clientHelper.write(requestPayload); - final ByteBuf actualRequestPayload = serverHelper.read(initiate); - assertThat(ByteBufUtil.getBytes(actualRequestPayload)).isEqualTo(requestPayload); - - assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.WRITE_MESSAGE); - - final byte[] respond = serverHelper.write(responsePayload); - byte[] actualResponsePayload = clientHelper.read(respond); - assertThat(actualResponsePayload).isEqualTo(responsePayload); - - assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.SPLIT); - assertThatNoException().isThrownBy(() -> serverHelper.getHandshakeState().split()); - assertThatNoException().isThrownBy(() -> clientHelper.split()); - } - -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java deleted file mode 100644 index d7d3751ac..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/ProxyProtocolDetectionHandlerTest.java +++ /dev/null @@ -1,108 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFuture; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import java.util.HexFormat; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -class ProxyProtocolDetectionHandlerTest { - - private EmbeddedChannel embeddedChannel; - - private static final byte[] PROXY_V2_MESSAGE_BYTES = - HexFormat.of().parseHex("0d0a0d0a000d0a515549540a2111000c0a0000010a000002303901bb"); - - @BeforeEach - void setUp() { - embeddedChannel = new EmbeddedChannel(new ProxyProtocolDetectionHandler()); - } - - @Test - void singlePacketProxyMessage() throws InterruptedException { - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES)); - embeddedChannel.flushInbound(); - - writeFuture.await(); - - assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); - assertEquals(1, embeddedChannel.inboundMessages().size()); - assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll()); - } - - @Test - void multiPacketProxyMessage() throws InterruptedException { - final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound( - Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, 0, - ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)); - - final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound( - Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1, - PROXY_V2_MESSAGE_BYTES.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1))); - - embeddedChannel.flushInbound(); - - firstWriteFuture.await(); - secondWriteFuture.await(); - - assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); - assertEquals(1, embeddedChannel.inboundMessages().size()); - assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll()); - } - - @Test - void singlePacketNonProxyMessage() throws InterruptedException { - final byte[] nonProxyProtocolMessage = new byte[32]; - ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage); - - final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(nonProxyProtocolMessage)); - embeddedChannel.flushInbound(); - - writeFuture.await(); - - assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); - assertEquals(1, embeddedChannel.inboundMessages().size()); - - final Object inboundMessage = embeddedChannel.inboundMessages().poll(); - - assertInstanceOf(ByteBuf.class, inboundMessage); - assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage)); - } - - @Test - void multiPacketNonProxyMessage() throws InterruptedException { - final byte[] nonProxyProtocolMessage = new byte[32]; - ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage); - - final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound( - Unpooled.wrappedBuffer(nonProxyProtocolMessage, 0, - ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)); - - final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound( - Unpooled.wrappedBuffer(nonProxyProtocolMessage, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1, - nonProxyProtocolMessage.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1))); - - embeddedChannel.flushInbound(); - - firstWriteFuture.await(); - secondWriteFuture.await(); - - assertTrue(embeddedChannel.pipeline().toMap().isEmpty()); - assertEquals(1, embeddedChannel.inboundMessages().size()); - - final Object inboundMessage = embeddedChannel.inboundMessages().poll(); - - assertInstanceOf(ByteBuf.class, inboundMessage); - assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage)); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java deleted file mode 100644 index 0b6e2c0ac..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ClientErrorHandler.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class ClientErrorHandler extends ChannelInboundHandlerAdapter { - private static final Logger log = LoggerFactory.getLogger(ClientErrorHandler.class); - - @Override - public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) { - log.error("Caught inbound error in client; closing connection", cause); - context.channel().close(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java deleted file mode 100644 index 13c8b9924..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/CloseFrameEvent.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos; - -public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) { - - public enum CloseReason { - OK, - SERVER_CLOSED, - NOISE_ERROR, - NOISE_HANDSHAKE_ERROR, - INTERNAL_SERVER_ERROR, - UNKNOWN - } - - public enum CloseInitiator { - SERVER, - CLIENT - } - - public static CloseFrameEvent fromWebsocketCloseFrame( - CloseWebSocketFrame closeWebSocketFrame, - CloseInitiator closeInitiator) { - final CloseReason code = switch (closeWebSocketFrame.statusCode()) { - case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR; - case 4002 -> CloseReason.NOISE_ERROR; - case 1011 -> CloseReason.INTERNAL_SERVER_ERROR; - case 1012 -> CloseReason.SERVER_CLOSED; - case 1000 -> CloseReason.OK; - default -> CloseReason.UNKNOWN; - }; - return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText()); - } - - public static CloseFrameEvent fromNoiseDirectCloseFrame( - NoiseDirectProtos.CloseReason noiseDirectCloseReason, - CloseInitiator closeInitiator) { - final CloseReason code = switch (noiseDirectCloseReason.getCode()) { - case OK -> CloseReason.OK; - case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR; - case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR; - case UNAVAILABLE -> CloseReason.SERVER_CLOSED; - case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR; - case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN; - }; - return new CloseFrameEvent(code, closeInitiator, noiseDirectCloseReason.getMessage()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java deleted file mode 100644 index c735902b6..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/EstablishRemoteConnectionHandler.java +++ /dev/null @@ -1,120 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import io.netty.bootstrap.Bootstrap; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.util.ReferenceCountUtil; -import java.net.SocketAddress; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; -import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler; - -/** - * Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote - * gRPC server. - *

- * This handler waits until the first gRPC client message is ready and then establishes a connection with the remote - * gRPC server. It expects the provided remoteHandlerStack to emit a {@link ReadyForNoiseHandshakeEvent} when the remote - * connection is ready for its first inbound payload, and to emit a {@link NoiseClientHandshakeCompleteEvent} when the - * handshake is finished. - */ -class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter { - - private final List remoteHandlerStack; - private final NoiseTunnelProtos.HandshakeInit handshakeInit; - - private final SocketAddress remoteServerAddress; - // If provided, will be sent with the payload in the noise handshake - - private final List pendingReads = new ArrayList<>(); - - private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake"; - - EstablishRemoteConnectionHandler( - final List remoteHandlerStack, - final SocketAddress remoteServerAddress, - final NoiseTunnelProtos.HandshakeInit handshakeInit) { - this.remoteHandlerStack = remoteHandlerStack; - this.handshakeInit = handshakeInit; - this.remoteServerAddress = remoteServerAddress; - } - - @Override - public void handlerAdded(final ChannelHandlerContext localContext) { - new Bootstrap() - .channel(NioSocketChannel.class) - .group(localContext.channel().eventLoop()) - .handler(new ChannelInitializer() { - @Override - protected void initChannel(final SocketChannel channel) throws Exception { - - for (ChannelHandler handler : remoteHandlerStack) { - channel.pipeline().addLast(handler); - } - channel.pipeline() - .addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event) - throws Exception { - switch (event) { - case ReadyForNoiseHandshakeEvent ignored -> - remoteContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeInit.toByteArray())) - .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); - case NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) -> { - remoteContext.pipeline() - .replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel())); - localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel())); - - // If there was a payload response on the handshake, write it back to our gRPC client - if (!handshakeResponse.getFastOpenResponse().isEmpty()) { - localContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeResponse - .getFastOpenResponse() - .asReadOnlyByteBuffer())); - } - - // Forward any messages we got from our gRPC client, now will be proxied to the remote context - pendingReads.forEach(localContext::fireChannelRead); - pendingReads.clear(); - localContext.pipeline().remove(EstablishRemoteConnectionHandler.this); - } - default -> { - } - } - super.userEventTriggered(remoteContext, event); - } - }) - .addLast(new ClientErrorHandler()); - } - }) - .connect(remoteServerAddress) - .addListener((ChannelFutureListener) future -> { - if (future.isSuccess()) { - // Close the local connection if the remote channel closes and vice versa - future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close()); - localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close()); - } else { - localContext.close(); - } - }); - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) { - pendingReads.add(message); - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - pendingReads.forEach(ReferenceCountUtil::release); - pendingReads.clear(); - } - -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/FastOpenRequestBufferedEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/FastOpenRequestBufferedEvent.java deleted file mode 100644 index dae7c51d6..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/FastOpenRequestBufferedEvent.java +++ /dev/null @@ -1,9 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import io.netty.buffer.ByteBuf; - -public record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/HAProxyMessageSender.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/HAProxyMessageSender.java deleted file mode 100644 index 4237d889f..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/HAProxyMessageSender.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import java.util.function.Supplier; - -class HAProxyMessageSender extends ChannelInboundHandlerAdapter { - - private final Supplier messageSupplier; - - HAProxyMessageSender(final Supplier messageSupplier) { - this.messageSupplier = messageSupplier; - } - - @Override - public void handlerAdded(final ChannelHandlerContext context) { - if (context.channel().isActive()) { - context.writeAndFlush(messageSupplier.get()); - } - } - - @Override - public void channelActive(final ChannelHandlerContext context) { - context.writeAndFlush(messageSupplier.get()); - context.fireChannelActive(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/Http2Buffering.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/Http2Buffering.java deleted file mode 100644 index e476ca887..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/Http2Buffering.java +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandler; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; -import io.netty.util.ReferenceCountUtil; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HexFormat; -import java.util.List; -import java.util.stream.Stream; - -/** - * The noise tunnel streams bytes out of a gRPC client through noise and to a remote server. The server supports a "fast - * open" optimization where the client can send a request along with the noise handshake. There's no direct way to - * extract the request boundaries from the gRPC client's byte-stream, so {@link Http2Buffering#handler()} provides an - * inbound pipeline handler that will parse the byte-stream back into HTTP/2 frames and buffer the first request. - *

- * Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a - * {@link FastOpenRequestBufferedEvent} - */ -public class Http2Buffering { - - /** - * Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request - */ - public static ChannelInboundHandler handler() { - return new Http2PrefaceHandler(); - } - - private Http2Buffering() { - } - - private static class Http2PrefaceHandler extends ChannelInboundHandlerAdapter { - - // https://www.rfc-editor.org/rfc/rfc7540.html#section-3.5 - private static final byte[] HTTP2_PREFACE = - HexFormat.of().parseHex("505249202a20485454502f322e300d0a0d0a534d0d0a0d0a"); - private final ByteBuf read = Unpooled.buffer(HTTP2_PREFACE.length, HTTP2_PREFACE.length); - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) { - if (message instanceof ByteBuf bb) { - bb.readBytes(read); - if (read.readableBytes() < HTTP2_PREFACE.length) { - // Copied the message into the read buffer, but haven't yet got a full HTTP2 preface. Wait for more. - return; - } - if (!Arrays.equals(read.array(), HTTP2_PREFACE)) { - throw new IllegalStateException("HTTP/2 stream must start with HTTP/2 preface"); - } - context.pipeline().replace(this, "http2frame1", new Http2LengthFieldFrameDecoder()); - context.pipeline().addAfter("http2frame1", "http2frame2", new Http2FrameDecoder()); - context.pipeline().addAfter("http2frame2", "http2frame3", new Http2FirstRequestHandler()); - context.fireChannelRead(bb); - } else { - throw new IllegalStateException("Unexpected message: " + message); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - ReferenceCountUtil.release(read); - } - } - - - private record Http2Frame(ByteBuf bytes, FrameType type, boolean endStream) { - - private static final byte FLAG_END_STREAM = 0x01; - - enum FrameType { - SETTINGS, - HEADERS, - DATA, - WINDOW_UPDATE, - OTHER; - - static FrameType fromSerializedType(final byte type) { - return switch (type) { - case 0x00 -> Http2Frame.FrameType.DATA; - case 0x01 -> Http2Frame.FrameType.HEADERS; - case 0x04 -> Http2Frame.FrameType.SETTINGS; - case 0x08 -> Http2Frame.FrameType.WINDOW_UPDATE; - default -> Http2Frame.FrameType.OTHER; - }; - } - } - } - - /** - * Emit ByteBuf of entire HTTP/2 frame - */ - private static class Http2LengthFieldFrameDecoder extends LengthFieldBasedFrameDecoder { - - public Http2LengthFieldFrameDecoder() { - // Frames are 3 bytes of length, 6 bytes of other header, and then length bytes of payload - super(16 * 1024 * 1024, 0, 3, 6, 0); - } - } - - /** - * Parse the serialized Http/2 frames into {@link Http2Frame} objects - */ - private static class Http2FrameDecoder extends ByteToMessageDecoder { - - @Override - protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List out) throws Exception { - // https://www.rfc-editor.org/rfc/rfc7540.html#section-4.1 - final Http2Frame.FrameType frameType = Http2Frame.FrameType.fromSerializedType(in.getByte(in.readerIndex() + 3)); - final boolean endStream = endStream(frameType, in.getByte(in.readerIndex() + 4)); - out.add(new Http2Frame(in.readBytes(in.readableBytes()), frameType, endStream)); - } - - boolean endStream(Http2Frame.FrameType frameType, byte flags) { - // A gRPC request are packed into HTTP/2 frames like: - // HEADERS frame | DATA frame 1 (endStream=0) | ... | DATA frame N (endstream=1) - // - // Our goal is to get an entire request buffered, so as soon as we see a DATA frame with the end stream flag set - // we have a whole request. Note that we could have pieces of multiple requests, but the only thing we care about - // is having at least one complete request. In total, we can expect something like: - // HTTP-preface | SETTINGS frame | Frames we don't care about ... | DATA (endstream=1) - // - // The connection isn't 'established' until the server has responded with their own SETTINGS frame with the ack - // bit set, but HTTP/2 allows the client to send frames before getting the ACK. - if (frameType == Http2Frame.FrameType.DATA) { - return (flags & Http2Frame.FLAG_END_STREAM) == Http2Frame.FLAG_END_STREAM; - } - - // In theory, at least. Unfortunately, the java gRPC client always waits for the HTTP/2 handshake to complete - // (which requires the server sending back the ack) before it actually sends any requests. So if we waited for a - // DATA frame, it would never come. The gRPC-java implementation always at least sends a WINDOW_UPDATE, so we - // might as well pack that in. - return frameType == Http2Frame.FrameType.WINDOW_UPDATE; - } - } - - /** - * Collect HTTP/2 frames until we get an entire "request" to send - */ - private static class Http2FirstRequestHandler extends ChannelInboundHandlerAdapter { - - final List pendingFrames = new ArrayList<>(); - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) { - if (message instanceof Http2Frame http2Frame) { - if (pendingFrames.isEmpty() && http2Frame.type != Http2Frame.FrameType.SETTINGS) { - throw new IllegalStateException( - "HTTP/2 stream must start with HTTP/2 SETTINGS frame, got " + http2Frame.type); - } - pendingFrames.add(http2Frame); - if (http2Frame.endStream) { - // We have a whole "request", emit the first request event and remove the http2 buffering handlers - final ByteBuf request = Unpooled.wrappedBuffer(Stream.concat( - Stream.of(Unpooled.wrappedBuffer(Http2PrefaceHandler.HTTP2_PREFACE)), - pendingFrames.stream().map(Http2Frame::bytes)) - .toArray(ByteBuf[]::new)); - pendingFrames.clear(); - context.pipeline().remove(Http2LengthFieldFrameDecoder.class); - context.pipeline().remove(Http2FrameDecoder.class); - context.pipeline().remove(this); - context.fireUserEventTriggered(new FastOpenRequestBufferedEvent(request)); - } - } else { - throw new IllegalStateException("Unexpected message: " + message); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - pendingFrames.forEach(frame -> ReferenceCountUtil.release(frame.bytes())); - pendingFrames.clear(); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java deleted file mode 100644 index 66ef45c9f..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeCompleteEvent.java +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; - -import java.util.Optional; - -/** - * A netty user event that indicates that the noise handshake finished successfully. - * - * @param fastResponse A response if the client included a request to send in the initiate handshake message payload and - * the server included a payload in the handshake response. - */ -public record NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) {} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java deleted file mode 100644 index c45d8639c..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHandler.java +++ /dev/null @@ -1,62 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import com.google.protobuf.InvalidProtocolBufferException; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import java.util.Optional; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; -import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; -import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage; - -public class NoiseClientHandshakeHandler extends ChannelDuplexHandler { - - private final NoiseClientHandshakeHelper handshakeHelper; - - public NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper) { - this.handshakeHelper = handshakeHelper; - } - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof ByteBuf plaintextHandshakePayload) { - final byte[] payloadBytes = ByteBufUtil.getBytes(plaintextHandshakePayload, - plaintextHandshakePayload.readerIndex(), plaintextHandshakePayload.readableBytes(), - false); - final byte[] handshakeMessage = handshakeHelper.write(payloadBytes); - ctx.write(Unpooled.wrappedBuffer(handshakeMessage), promise); - } else { - ctx.write(msg, promise); - } - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) - throws NoiseHandshakeException { - if (message instanceof ByteBuf frame) { - try { - final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame)); - final NoiseTunnelProtos.HandshakeResponse handshakeResponse = - NoiseTunnelProtos.HandshakeResponse.parseFrom(payload); - - context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split())); - context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(handshakeResponse)); - } catch (InvalidProtocolBufferException e) { - throw new NoiseHandshakeException("Failed to parse handshake response"); - } finally { - frame.release(); - } - } else { - context.fireChannelRead(message); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - handshakeHelper.destroy(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHelper.java deleted file mode 100644 index df95bf762..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientHandshakeHelper.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import com.southernstorm.noise.protocol.CipherStatePair; -import com.southernstorm.noise.protocol.HandshakeState; -import java.security.NoSuchAlgorithmException; -import javax.crypto.BadPaddingException; -import javax.crypto.ShortBufferException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException; - -public class NoiseClientHandshakeHelper { - - private final HandshakePattern handshakePattern; - private final HandshakeState handshakeState; - - private NoiseClientHandshakeHelper(HandshakePattern handshakePattern, HandshakeState handshakeState) { - this.handshakePattern = handshakePattern; - this.handshakeState = handshakeState; - } - - public static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) { - try { - final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR); - state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0); - state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0); - state.start(); - return new NoiseClientHandshakeHelper(HandshakePattern.IK, state); - } catch (NoSuchAlgorithmException e) { - throw new IllegalArgumentException(e); - } - } - - public static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) { - try { - final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR); - state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0); - state.start(); - return new NoiseClientHandshakeHelper(HandshakePattern.NK, state); - } catch (NoSuchAlgorithmException e) { - throw new IllegalArgumentException(e); - } - } - - public byte[] write(final byte[] requestPayload) throws ShortBufferException { - final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16]; - handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length); - return initiateHandshakeMessage; - } - - private int initiateHandshakeKeysLength() { - return switch (handshakePattern) { - // 32-byte ephemeral key, 32-byte encrypted static key, 16-byte AEAD tag - case IK -> 32 + 32 + 16; - // 32-byte ephemeral key - case NK -> 32; - }; - } - - public byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException { - // Don't process additional messages if the handshake failed and we're just waiting to close - if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) { - throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction()); - } - final int payloadLength = responderHandshakeMessage.length - 16 - 32; - final byte[] responsePayload = new byte[payloadLength]; - final int payloadBytesRead; - try { - payloadBytesRead = handshakeState - .readMessage(responderHandshakeMessage, 0, responderHandshakeMessage.length, responsePayload, 0); - if (payloadBytesRead != responsePayload.length) { - throw new IllegalStateException( - "Unexpected payload length, required " + payloadLength + " got " + payloadBytesRead); - } - return responsePayload; - } catch (ShortBufferException e) { - throw new IllegalStateException("Failed to deserialize payload of known length" + e.getMessage()); - } catch (BadPaddingException e) { - throw new NoiseHandshakeException(e.getMessage()); - } - } - - public CipherStatePair split() { - return this.handshakeState.split(); - } - - public void destroy() { - this.handshakeState.destroy(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java deleted file mode 100644 index b6b6f0a24..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseClientTransportHandler.java +++ /dev/null @@ -1,89 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import com.southernstorm.noise.protocol.CipherState; -import com.southernstorm.noise.protocol.CipherStatePair; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.util.ReferenceCountUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame; - -/** - * A Noise transport handler manages a bidirectional Noise session after a handshake has completed. - */ -public class NoiseClientTransportHandler extends ChannelDuplexHandler { - - private final CipherStatePair cipherStatePair; - - private static final Logger log = LoggerFactory.getLogger(NoiseClientTransportHandler.class); - - NoiseClientTransportHandler(CipherStatePair cipherStatePair) { - this.cipherStatePair = cipherStatePair; - } - - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - try { - if (message instanceof ByteBuf frame) { - final CipherState cipherState = cipherStatePair.getReceiver(); - - // We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array. - // We'll need to copy it to a heap buffer. - final byte[] noiseBuffer = ByteBufUtil.getBytes(frame); - - // Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer - final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length); - - context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength)); - } else { - // Anything except binary frames should have been filtered out of the pipeline by now; treat this as an - // error - throw new IllegalArgumentException("Unexpected message in pipeline: " + message); - } - } finally { - ReferenceCountUtil.release(message); - } - } - - - @Override - public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) - throws Exception { - if (message instanceof ByteBuf plaintext) { - try { - final CipherState cipherState = cipherStatePair.getSender(); - final int plaintextLength = plaintext.readableBytes(); - - // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the - // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a - // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC. - final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()]; - plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes()); - - // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer - cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); - - context.write(Unpooled.wrappedBuffer(noiseBuffer), promise); - } finally { - ReferenceCountUtil.release(plaintext); - } - } else { - if (!(message instanceof CloseWebSocketFrame || message instanceof NoiseDirectFrame)) { - // Clients only write ByteBufs or a close frame on errors, so any other message is unexpected - log.warn("Unexpected object in pipeline: {}", message); - } - context.write(message, promise); - } - } - - @Override - public void handlerRemoved(final ChannelHandlerContext context) { - cipherStatePair.destroy(); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java deleted file mode 100644 index 45383318b..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/NoiseTunnelClient.java +++ /dev/null @@ -1,408 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.client; - -import com.google.protobuf.ByteString; -import com.southernstorm.noise.protocol.Noise; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.ChannelOutboundHandlerAdapter; -import io.netty.channel.ChannelPromise; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.local.LocalServerChannel; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; -import io.netty.handler.codec.MessageToMessageCodec; -import io.netty.handler.codec.haproxy.HAProxyMessage; -import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.HttpClientCodec; -import io.netty.handler.codec.http.HttpHeaders; -import io.netty.handler.codec.http.HttpObjectAggregator; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; -import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus; -import io.netty.handler.codec.http.websocketx.WebSocketVersion; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.util.ReferenceCountUtil; -import java.net.SocketAddress; -import java.net.URI; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.function.Function; -import java.util.function.Supplier; -import javax.net.ssl.SSLException; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.grpc.net.FramingType; -import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos; -import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame; -import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec; -import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos; -import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec; -import org.whispersystems.textsecuregcm.util.UUIDUtil; - -public class NoiseTunnelClient implements AutoCloseable { - - private final CompletableFuture closeEventFuture; - private final CompletableFuture handshakeEventFuture; - private final CompletableFuture userCloseFuture; - private final ServerBootstrap serverBootstrap; - private Channel serverChannel; - - public static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated"); - public static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous"); - - public static class Builder { - - final SocketAddress remoteServerAddress; - NioEventLoopGroup eventLoopGroup; - ECPublicKey serverPublicKey; - - FramingType framingType = FramingType.WEBSOCKET; - URI websocketUri = ANONYMOUS_WEBSOCKET_URI; - HttpHeaders headers = new DefaultHttpHeaders(); - NoiseTunnelProtos.HandshakeInit.Builder handshakeInit = NoiseTunnelProtos.HandshakeInit.newBuilder(); - - boolean authenticated = false; - ECKeyPair ecKeyPair = null; - boolean useTls; - X509Certificate trustedServerCertificate = null; - Supplier proxyMessageSupplier = null; - - public Builder( - final SocketAddress remoteServerAddress, - final NioEventLoopGroup eventLoopGroup, - final ECPublicKey serverPublicKey) { - this.remoteServerAddress = remoteServerAddress; - this.eventLoopGroup = eventLoopGroup; - this.serverPublicKey = serverPublicKey; - } - - public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) { - this.authenticated = true; - handshakeInit.setAci(UUIDUtil.toByteString(accountIdentifier)); - handshakeInit.setDeviceId(deviceId); - this.ecKeyPair = ecKeyPair; - this.websocketUri = AUTHENTICATED_WEBSOCKET_URI; - return this; - } - - public Builder setWebsocketUri(final URI websocketUri) { - this.websocketUri = websocketUri; - return this; - } - - public Builder setUseTls(X509Certificate trustedServerCertificate) { - this.useTls = true; - this.trustedServerCertificate = trustedServerCertificate; - return this; - } - - public Builder setProxyMessageSupplier(Supplier proxyMessageSupplier) { - this.proxyMessageSupplier = proxyMessageSupplier; - return this; - } - - public Builder setUserAgent(final String userAgent) { - handshakeInit.setUserAgent(userAgent); - return this; - } - - public Builder setAcceptLanguage(final String acceptLanguage) { - handshakeInit.setAcceptLanguage(acceptLanguage); - return this; - } - - public Builder setHeaders(final HttpHeaders headers) { - this.headers = headers; - return this; - } - - public Builder setServerPublicKey(ECPublicKey serverPublicKey) { - this.serverPublicKey = serverPublicKey; - return this; - } - - public Builder setFramingType(FramingType framingType) { - this.framingType = framingType; - return this; - } - - public NoiseTunnelClient build() { - final List handlers = new ArrayList<>(); - if (proxyMessageSupplier != null) { - handlers.addAll(List.of(HAProxyMessageEncoder.INSTANCE, new HAProxyMessageSender(proxyMessageSupplier))); - } - if (useTls) { - final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); - - if (trustedServerCertificate != null) { - sslContextBuilder.trustManager(trustedServerCertificate); - } - - try { - handlers.add(sslContextBuilder.build().newHandler(ByteBufAllocator.DEFAULT)); - } catch (SSLException e) { - throw new IllegalArgumentException(e); - } - } - - // handles the wrapping and unrwrapping the framing layer (websockets or noisedirect) - handlers.addAll(switch (framingType) { - case WEBSOCKET -> websocketHandlerStack(websocketUri, headers); - case NOISE_DIRECT -> noiseDirectHandlerStack(authenticated); - }); - - final NoiseClientHandshakeHelper helper = authenticated - ? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair) - : NoiseClientHandshakeHelper.NK(serverPublicKey); - - handlers.add(new NoiseClientHandshakeHandler(helper)); - - // When the noise handshake completes we'll save the response from the server so client users can inspect it - final UserEventFuture handshakeEventHandler = - new UserEventFuture<>(NoiseClientHandshakeCompleteEvent.class); - handlers.add(handshakeEventHandler); - - // Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off - // information about why the connection was closed. - final UserEventFuture closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class); - handlers.add(closeEventHandler); - - // When the user closes the client, write a normal closure close frame - final CompletableFuture userCloseFuture = new CompletableFuture<>(); - handlers.add(new ChannelInboundHandlerAdapter() { - @Override - public void handlerAdded(final ChannelHandlerContext ctx) { - userCloseFuture.thenRunAsync(() -> ctx.pipeline().writeAndFlush(switch (framingType) { - case WEBSOCKET -> new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE); - case NOISE_DIRECT -> new NoiseDirectFrame( - NoiseDirectFrame.FrameType.CLOSE, - Unpooled.wrappedBuffer(NoiseDirectProtos.CloseReason - .newBuilder() - .setCode(NoiseDirectProtos.CloseReason.Code.OK) - .build() - .toByteArray())); - }) - .addListener(ChannelFutureListener.CLOSE), - ctx.executor()); - } - }); - - final NoiseTunnelClient client = - new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, handshakeEventHandler.future, userCloseFuture, fastOpenRequest -> new EstablishRemoteConnectionHandler( - handlers, - remoteServerAddress, - handshakeInit.setFastOpenRequest(ByteString.copyFrom(fastOpenRequest)).build())); - client.start(); - return client; - } - } - - private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup, - CompletableFuture closeEventFuture, - CompletableFuture handshakeEventFuture, - CompletableFuture userCloseFuture, - Function handler) { - - this.userCloseFuture = userCloseFuture; - this.closeEventFuture = closeEventFuture; - this.handshakeEventFuture = handshakeEventFuture; - this.serverBootstrap = new ServerBootstrap() - .localAddress(new LocalAddress("websocket-noise-tunnel-client")) - .channel(LocalServerChannel.class) - .group(eventLoopGroup) - .childHandler(new ChannelInitializer() { - @Override - protected void initChannel(final LocalChannel localChannel) { - localChannel.pipeline() - // We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the - // stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put - // in the handshake. - .addLast(Http2Buffering.handler()) - // Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At - // that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually - // connect to the remote service - .addLast(new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { - if (evt instanceof FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest)) { - byte[] fastOpenRequestBytes = ByteBufUtil.getBytes(fastOpenRequest); - fastOpenRequest.release(); - ctx.pipeline().addLast(handler.apply(fastOpenRequestBytes)); - } - super.userEventTriggered(ctx, evt); - } - }) - .addLast(new ClientErrorHandler()); - } - }); - } - - private static class UserEventFuture extends ChannelInboundHandlerAdapter { - private final CompletableFuture future = new CompletableFuture<>(); - private final Class cls; - - UserEventFuture(Class cls) { - this.cls = cls; - } - - @Override - public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { - if (cls.isInstance(evt)) { - future.complete((T) evt); - } - ctx.fireUserEventTriggered(evt); - } - } - - - public LocalAddress getLocalAddress() { - return (LocalAddress) serverChannel.localAddress(); - } - - private NoiseTunnelClient start() { - serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel(); - return this; - } - - @Override - public void close() throws InterruptedException { - userCloseFuture.complete(null); - serverChannel.close().await(); - } - - /** - * @return A future that completes when a close frame is observed - */ - public CompletableFuture closeFrameFuture() { - return closeEventFuture; - } - - /** - * @return A future that completes when the noise handshake finishes - */ - public CompletableFuture getHandshakeEventFuture() { - return handshakeEventFuture; - } - - - private static List noiseDirectHandlerStack(boolean authenticated) { - return List.of( - new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2), - new NoiseDirectFrameCodec(), - new ChannelDuplexHandler() { - @Override - public void channelActive(ChannelHandlerContext ctx) { - ctx.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent()); - ctx.fireChannelActive(); - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) { - try { - final NoiseDirectProtos.CloseReason closeReason = - NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content())); - ctx.fireUserEventTriggered( - CloseFrameEvent.fromNoiseDirectCloseFrame(closeReason, CloseFrameEvent.CloseInitiator.SERVER)); - } finally { - ReferenceCountUtil.release(msg); - } - } else { - ctx.fireChannelRead(msg); - } - } - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) { - final NoiseDirectProtos.CloseReason errorPayload = - NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content())); - ctx.fireUserEventTriggered( - CloseFrameEvent.fromNoiseDirectCloseFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT)); - } - ctx.write(msg, promise); - } - }, - new MessageToMessageCodec() { - boolean noiseHandshakeFinished = false; - - @Override - protected void encode(final ChannelHandlerContext ctx, final ByteBuf msg, final List out) { - final NoiseDirectFrame.FrameType frameType = noiseHandshakeFinished - ? NoiseDirectFrame.FrameType.DATA - : (authenticated ? NoiseDirectFrame.FrameType.IK_HANDSHAKE : NoiseDirectFrame.FrameType.NK_HANDSHAKE); - noiseHandshakeFinished = true; - out.add(new NoiseDirectFrame(frameType, msg.retain())); - } - - @Override - protected void decode(final ChannelHandlerContext ctx, final NoiseDirectFrame msg, - final List out) { - out.add(msg.content().retain()); - } - }); - } - - private static List websocketHandlerStack(final URI websocketUri, final HttpHeaders headers) { - return List.of( - new HttpClientCodec(), - new HttpObjectAggregator(Noise.MAX_PACKET_LEN), - // Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we - // want to react to them on our own, we need to catch them before they hit that handler. - new ChannelInboundHandlerAdapter() { - @Override - public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception { - if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { - context.fireUserEventTriggered( - CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.SERVER)); - } - - super.channelRead(context, message); - } - }, - new WebSocketClientProtocolHandler(websocketUri, - WebSocketVersion.V13, - null, - false, - headers, - Noise.MAX_PACKET_LEN, - 10_000), - new ChannelOutboundHandlerAdapter() { - @Override - public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception { - if (message instanceof CloseWebSocketFrame closeWebSocketFrame) { - context.fireUserEventTriggered( - CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.CLIENT)); - } - super.write(context, message, promise); - } - }, - new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) { - if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) { - if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { - context.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent()); - } - } - context.fireUserEventTriggered(event); - } - }, - new WebsocketPayloadCodec()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ReadyForNoiseHandshakeEvent.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ReadyForNoiseHandshakeEvent.java deleted file mode 100644 index 90478599e..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/client/ReadyForNoiseHandshakeEvent.java +++ /dev/null @@ -1,4 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.client; - -public record ReadyForNoiseHandshakeEvent() { -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/DirectNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/DirectNoiseTunnelServerIntegrationTest.java deleted file mode 100644 index b826565d7..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/noisedirect/DirectNoiseTunnelServerIntegrationTest.java +++ /dev/null @@ -1,50 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.noisedirect; - -import io.netty.channel.local.LocalAddress; -import io.netty.channel.nio.NioEventLoopGroup; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.grpc.net.FramingType; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; -import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - -import java.util.concurrent.Executor; - -class DirectNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest { - private NoiseDirectTunnelServer noiseDirectTunnelServer; - - @Override - protected void start( - final NioEventLoopGroup eventLoopGroup, - final Executor delegatedTaskExecutor, - final GrpcClientConnectionManager grpcClientConnectionManager, - final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair serverKeyPair, - final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress, - final String recognizedProxySecret) throws Exception { - - noiseDirectTunnelServer = new NoiseDirectTunnelServer(0, - eventLoopGroup, - grpcClientConnectionManager, - clientPublicKeysManager, - serverKeyPair, - authenticatedGrpcServerAddress, - anonymousGrpcServerAddress); - noiseDirectTunnelServer.start(); - } - - @Override - protected void stop() throws InterruptedException { - noiseDirectTunnelServer.stop(); - } - - @Override - protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) { - return new NoiseTunnelClient - .Builder(noiseDirectTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey) - .setFramingType(FramingType.NOISE_DIRECT); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandlerTest.java deleted file mode 100644 index 1765e9c14..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/RejectUnsupportedMessagesHandlerTest.java +++ /dev/null @@ -1,72 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; -import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; -import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; -import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; -import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; -import io.netty.handler.codec.http.websocketx.WebSocketFrame; -import java.util.List; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest; - -class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest { - - private EmbeddedChannel embeddedChannel; - - @BeforeEach - void setUp() { - embeddedChannel = new EmbeddedChannel(new RejectUnsupportedMessagesHandler()); - } - - @ParameterizedTest - @MethodSource - void allowWebSocketFrame(final WebSocketFrame frame) { - embeddedChannel.writeOneInbound(frame); - - try { - assertEquals(frame, embeddedChannel.inboundMessages().poll()); - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - assertEquals(1, frame.refCnt()); - } finally { - frame.release(); - } - } - - private static List allowWebSocketFrame() { - return List.of( - new BinaryWebSocketFrame(), - new CloseWebSocketFrame(), - new ContinuationWebSocketFrame(), - new PingWebSocketFrame(), - new PongWebSocketFrame()); - } - - @Test - void rejectTextFrame() { - final TextWebSocketFrame textFrame = new TextWebSocketFrame(); - embeddedChannel.writeOneInbound(textFrame); - - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - assertEquals(0, textFrame.refCnt()); - } - - @Test - void rejectNonWebSocketFrame() { - final ByteBuf bytes = Unpooled.buffer(0); - embeddedChannel.writeOneInbound(bytes); - - assertTrue(embeddedChannel.inboundMessages().isEmpty()); - assertEquals(0, bytes.refCnt()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java deleted file mode 100644 index a38c9181e..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/TlsWebSocketNoiseTunnelServerIntegrationTest.java +++ /dev/null @@ -1,239 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import io.grpc.ManagedChannel; -import io.grpc.Status; -import io.netty.channel.local.LocalAddress; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.HttpHeaders; -import java.io.ByteArrayInputStream; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.security.KeyFactory; -import java.security.KeyStore; -import java.security.PrivateKey; -import java.security.SecureRandom; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; -import java.security.spec.PKCS8EncodedKeySpec; -import java.util.Base64; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.TimeoutException; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManagerFactory; -import org.junit.jupiter.api.Test; -import org.signal.chat.rpc.GetRequestAttributesRequest; -import org.signal.chat.rpc.GetRequestAttributesResponse; -import org.signal.chat.rpc.RequestAttributesGrpc; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils; -import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent; -import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - -class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest { - private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer; - private X509Certificate serverTlsCertificate; - - - // Please note that this certificate/key are used only for testing and are not used anywhere outside of this test. - // They were generated with: - // - // ```shell - // openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost" - // ``` - private static final String SERVER_CERTIFICATE = """ - -----BEGIN CERTIFICATE----- - MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw - FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx - MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA - IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV - jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq - SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME - GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG - SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw - XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi - iOr9sHiO8Rn2u0xRKgU5Ig== - -----END CERTIFICATE----- - """; - - // BEGIN/END PRIVATE KEY header/footer removed for easier parsing - private static final String SERVER_PRIVATE_KEY = """ - MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj - kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd - PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA - O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo= - """; - @Override - protected void start( - final NioEventLoopGroup eventLoopGroup, - final Executor delegatedTaskExecutor, - final GrpcClientConnectionManager grpcClientConnectionManager, - final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair serverKeyPair, - final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress, - final String recognizedProxySecret) throws Exception { - final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); - serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate( - new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8))); - final PrivateKey serverTlsPrivateKey; - final KeyFactory keyFactory = KeyFactory.getInstance("EC"); - serverTlsPrivateKey = - keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY))); - tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0, - new X509Certificate[]{serverTlsCertificate}, - serverTlsPrivateKey, - eventLoopGroup, - delegatedTaskExecutor, - grpcClientConnectionManager, - clientPublicKeysManager, - serverKeyPair, - authenticatedGrpcServerAddress, - anonymousGrpcServerAddress, - recognizedProxySecret); - tlsNoiseWebSocketTunnelServer.start(); - } - - @Override - protected void stop() throws InterruptedException { - tlsNoiseWebSocketTunnelServer.stop(); - } - - @Override - protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, - final ECPublicKey serverPublicKey) { - return new NoiseTunnelClient - .Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey) - .setUseTls(serverTlsCertificate); - } - - @Test - void getRequestAttributes() throws InterruptedException { - final String remoteAddress = "4.5.6.7"; - final String acceptLanguage = "en"; - final String userAgent = "Signal-Desktop/1.2.3 Linux"; - - final HttpHeaders headers = new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) - .add("X-Forwarded-For", remoteAddress); - - try (final NoiseTunnelClient client = anonymous() - .setHeaders(headers) - .setUserAgent(userAgent) - .setAcceptLanguage(acceptLanguage) - .build()) { - - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()); - - assertEquals(remoteAddress, response.getRemoteAddress()); - assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList()); - assertEquals(userAgent, response.getUserAgent()); - } finally { - channel.shutdown(); - } - } - } - - @Test - void connectAuthenticatedToAnonymousService() throws InterruptedException, ExecutionException, TimeoutException { - try (final NoiseTunnelClient client = authenticated() - .setWebsocketUri(NoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); - } - } - - - @Test - void connectAnonymousToAuthenticatedService() throws InterruptedException, ExecutionException, TimeoutException { - try (final NoiseTunnelClient client = anonymous() - .setWebsocketUri(NoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI) - .build()) { - final ManagedChannel channel = buildManagedChannel(client.getLocalAddress()); - - try { - //noinspection ResultOfMethodCallIgnored - GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, - () -> RequestAttributesGrpc.newBlockingStub(channel) - .getRequestAttributes(GetRequestAttributesRequest.newBuilder().build())); - } finally { - channel.shutdown(); - } - assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR); - } - } - - @Test - void rejectIllegalRequests() throws Exception { - - final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); - keyStore.load(null, null); - keyStore.setCertificateEntry("tunnel", serverTlsCertificate); - - final TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - - trustManagerFactory.init(keyStore); - - final SSLContext sslContext = SSLContext.getInstance("TLS"); - sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom()); - - final URI authenticatedUri = - new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated", - null, null); - - final URI incorrectUri = - new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect", - null, null); - - try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) { - assertEquals(405, httpClient.send(HttpRequest.newBuilder() - .uri(authenticatedUri) - .PUT(HttpRequest.BodyPublishers.ofString("test")) - .build(), - HttpResponse.BodyHandlers.ofString()).statusCode(), - "Non-GET requests should not be allowed"); - - assertEquals(426, httpClient.send(HttpRequest.newBuilder() - .GET() - .uri(authenticatedUri) - .build(), - HttpResponse.BodyHandlers.ofString()).statusCode(), - "GET requests without upgrade headers should not be allowed"); - - assertEquals(404, httpClient.send(HttpRequest.newBuilder() - .GET() - .uri(incorrectUri) - .build(), - HttpResponse.BodyHandlers.ofString()).statusCode(), - "GET requests to unrecognized URIs should not be allowed"); - } - } - - -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java deleted file mode 100644 index 2efb0e72c..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketNoiseTunnelServerIntegrationTest.java +++ /dev/null @@ -1,50 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import io.netty.channel.local.LocalAddress; -import io.netty.channel.nio.NioEventLoopGroup; -import java.util.concurrent.Executor; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest; -import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager; -import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient; -import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; - -class WebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest { - private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer; - - @Override - protected void start( - final NioEventLoopGroup eventLoopGroup, - final Executor delegatedTaskExecutor, - final GrpcClientConnectionManager grpcClientConnectionManager, - final ClientPublicKeysManager clientPublicKeysManager, - final ECKeyPair serverKeyPair, - final LocalAddress authenticatedGrpcServerAddress, - final LocalAddress anonymousGrpcServerAddress, - final String recognizedProxySecret) throws Exception { - plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0, - null, - null, - eventLoopGroup, - delegatedTaskExecutor, - grpcClientConnectionManager, - clientPublicKeysManager, - serverKeyPair, - authenticatedGrpcServerAddress, - anonymousGrpcServerAddress, - recognizedProxySecret); - plaintextNoiseWebSocketTunnelServer.start(); - } - - @Override - protected void stop() throws InterruptedException { - plaintextNoiseWebSocketTunnelServer.stop(); - } - - @Override - protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) { - return new NoiseTunnelClient - .Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandlerTest.java deleted file mode 100644 index 8bd1fad4f..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebSocketOpeningHandshakeHandlerTest.java +++ /dev/null @@ -1,115 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; - -import io.netty.buffer.Unpooled; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpHeaders; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest; - -class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest { - - private EmbeddedChannel embeddedChannel; - - private static final String AUTHENTICATED_PATH = "/authenticated"; - private static final String ANONYMOUS_PATH = "/anonymous"; - private static final String HEALTH_CHECK_PATH = "/health-check"; - - @BeforeEach - void setUp() { - embeddedChannel = - new EmbeddedChannel(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_PATH, ANONYMOUS_PATH, HEALTH_CHECK_PATH)); - } - - @ParameterizedTest - @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH }) - void handleValidRequest(final String path) { - final FullHttpRequest request = buildRequest(HttpMethod.GET, path, - new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)); - - try { - embeddedChannel.writeOneInbound(request); - - assertEquals(1, request.refCnt()); - assertEquals(1, embeddedChannel.inboundMessages().size()); - assertEquals(request, embeddedChannel.inboundMessages().poll()); - } finally { - request.release(); - } - } - - @Test - void handleHealthCheckRequest() { - final FullHttpRequest request = buildRequest(HttpMethod.GET, HEALTH_CHECK_PATH, new DefaultHttpHeaders()); - - embeddedChannel.writeOneInbound(request); - - assertEquals(0, request.refCnt()); - assertHttpResponse(HttpResponseStatus.NO_CONTENT); - } - - @ParameterizedTest - @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH }) - void handleUpgradeRequired(final String path) { - final FullHttpRequest request = buildRequest(HttpMethod.GET, path, new DefaultHttpHeaders()); - - embeddedChannel.writeOneInbound(request); - - assertEquals(0, request.refCnt()); - assertHttpResponse(HttpResponseStatus.UPGRADE_REQUIRED); - } - - @Test - void handleBadPath() { - final FullHttpRequest request = buildRequest(HttpMethod.GET, "/incorrect", - new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)); - - embeddedChannel.writeOneInbound(request); - - assertEquals(0, request.refCnt()); - assertHttpResponse(HttpResponseStatus.NOT_FOUND); - } - - @ParameterizedTest - @ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH }) - void handleMethodNotAllowed(final String path) { - final FullHttpRequest request = buildRequest(HttpMethod.DELETE, path, - new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)); - - embeddedChannel.writeOneInbound(request); - - assertEquals(0, request.refCnt()); - assertHttpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED); - } - - private void assertHttpResponse(final HttpResponseStatus expectedStatus) { - assertEquals(1, embeddedChannel.outboundMessages().size()); - - final FullHttpResponse response = assertInstanceOf(FullHttpResponse.class, embeddedChannel.outboundMessages().poll()); - - assertEquals(expectedStatus, response.status()); - } - - private FullHttpRequest buildRequest(final HttpMethod method, final String path, final HttpHeaders headers) { - return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, - method, - path, - Unpooled.buffer(0), - headers, - new DefaultHttpHeaders()); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java deleted file mode 100644 index e27311489..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/websocket/WebsocketHandshakeCompleteHandlerTest.java +++ /dev/null @@ -1,233 +0,0 @@ -package org.whispersystems.textsecuregcm.grpc.net.websocket; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.params.provider.Arguments.argumentSet; -import static org.junit.jupiter.params.provider.Arguments.arguments; - -import com.google.common.net.InetAddresses; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.channel.local.LocalAddress; -import io.netty.handler.codec.http.DefaultHttpHeaders; -import io.netty.handler.codec.http.HttpHeaders; -import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.stream.Stream; -import javax.annotation.Nullable; -import org.apache.commons.lang3.RandomStringUtils; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest; -import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern; -import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit; -import org.whispersystems.textsecuregcm.util.TestRandomUtil; - -class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest { - - private UserEventRecordingHandler userEventRecordingHandler; - private MutableRemoteAddressEmbeddedChannel embeddedChannel; - - private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16); - - private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter { - - private final List receivedEvents = new ArrayList<>(); - - @Override - public void userEventTriggered(final ChannelHandlerContext context, final Object event) { - receivedEvents.add(event); - } - - public List getReceivedEvents() { - return receivedEvents; - } - } - - private static class MutableRemoteAddressEmbeddedChannel extends EmbeddedChannel { - - private SocketAddress remoteAddress; - - public MutableRemoteAddressEmbeddedChannel(final ChannelHandler... handlers) { - super(handlers); - } - - @Override - protected SocketAddress remoteAddress0() { - return isActive() ? remoteAddress : null; - } - - public void setRemoteAddress(final SocketAddress remoteAddress) { - this.remoteAddress = remoteAddress; - } - } - - @BeforeEach - void setUp() { - userEventRecordingHandler = new UserEventRecordingHandler(); - - embeddedChannel = new MutableRemoteAddressEmbeddedChannel( - new WebsocketHandshakeCompleteHandler(RECOGNIZED_PROXY_SECRET), - userEventRecordingHandler); - - embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0)); - } - - @ParameterizedTest - @MethodSource - void handleWebSocketHandshakeComplete(final String uri, final HandshakePattern pattern) { - final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = - new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null); - - embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); - assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents()); - - final byte[] payload = TestRandomUtil.nextBytes(100); - embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload)); - assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); - final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll(); - assertNotNull(init); - assertEquals(init.getHandshakePattern(), pattern); - } - - private static List handleWebSocketHandshakeComplete() { - return List.of( - Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, HandshakePattern.IK), - Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, HandshakePattern.NK)); - } - - @Test - void handleWebSocketHandshakeCompleteUnexpectedPath() { - final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = - new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null); - - embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); - - assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); - assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException()); - } - - @Test - void handleUnrecognizedEvent() { - final Object unrecognizedEvent = new Object(); - - embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent); - assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents()); - } - - @ParameterizedTest - @MethodSource - void getRemoteAddress(final HttpHeaders headers, final SocketAddress remoteAddress, @Nullable InetAddress expectedRemoteAddress) { - final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent = - new WebSocketServerProtocolHandler.HandshakeComplete( - NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, headers, null); - - embeddedChannel.setRemoteAddress(remoteAddress); - embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent); - - final byte[] payload = TestRandomUtil.nextBytes(100); - embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload)); - final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll(); - assertEquals( - expectedRemoteAddress, - Optional.ofNullable(init) - .map(NoiseHandshakeInit::getRemoteAddress) - .orElse(null)); - if (expectedRemoteAddress == null) { - assertThrows(IllegalStateException.class, embeddedChannel::checkException); - } else { - assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class)); - } - } - - private static List getRemoteAddress() { - final InetSocketAddress remoteAddress = new InetSocketAddress("5.6.7.8", 0); - final InetAddress clientAddress = InetAddresses.forString("1.2.3.4"); - final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1"); - - return List.of( - argumentSet("Recognized proxy, single forwarded-for address", - new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) - .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), - remoteAddress, - clientAddress), - - argumentSet("Recognized proxy, multiple forwarded-for addresses", - new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) - .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()), - remoteAddress, - proxyAddress), - - argumentSet("No recognized proxy header, single forwarded-for address", - new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), - remoteAddress, - remoteAddress.getAddress()), - - argumentSet("No recognized proxy header, no forwarded-for address", - new DefaultHttpHeaders(), - remoteAddress, - remoteAddress.getAddress()), - - argumentSet("Incorrect proxy header, single forwarded-for address", - new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect") - .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()), - remoteAddress, - remoteAddress.getAddress()), - - argumentSet("Recognized proxy, no forwarded-for address", - new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), - remoteAddress, - remoteAddress.getAddress()), - - argumentSet("Recognized proxy, bogus forwarded-for address", - new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET) - .add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"), - remoteAddress, - null), - - argumentSet("No forwarded-for address, non-InetSocketAddress remote address", - new DefaultHttpHeaders() - .add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET), - new LocalAddress("local-address"), - null) - ); - } - - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - @ParameterizedTest - @MethodSource("argumentsForGetMostRecentProxy") - void getMostRecentProxy(final String forwardedFor, final Optional expectedMostRecentProxy) { - assertEquals(expectedMostRecentProxy, WebsocketHandshakeCompleteHandler.getMostRecentProxy(forwardedFor)); - } - - private static Stream argumentsForGetMostRecentProxy() { - return Stream.of( - arguments(null, Optional.empty()), - arguments("", Optional.empty()), - arguments(" ", Optional.empty()), - arguments("203.0.113.195,", Optional.empty()), - arguments("203.0.113.195, ", Optional.empty()), - arguments("203.0.113.195", Optional.of("203.0.113.195")), - arguments("203.0.113.195, 70.41.3.18, 150.172.238.178", Optional.of("150.172.238.178")) - ); - } -} diff --git a/service/src/test/resources/config/test-secrets-bundle.yml b/service/src/test/resources/config/test-secrets-bundle.yml index bfa7c7490..99fd1e587 100644 --- a/service/src/test/resources/config/test-secrets-bundle.yml +++ b/service/src/test/resources/config/test-secrets-bundle.yml @@ -172,8 +172,3 @@ turn.cloudflare.apiToken: ABCDEFGHIJKLM linkDevice.secret: AAAAAAAAAAA= tlsKeyStore.password: unset - -# The below private key was generated exclusively for testing purposes. Do not use it in any other context. -# Corresponding public key: cYUAFtkWK/4x3AfW/yw7qgIo/mQUaRSWaPolGQkiL14= -noiseTunnel.noiseStaticPrivateKey: qK5FD9WmuhoLPsS/Z4swcZkwDn9OpeM5ZmcEVMpEQ24= -noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA diff --git a/service/src/test/resources/config/test.yml b/service/src/test/resources/config/test.yml index 07a52741f..ea6893f43 100644 --- a/service/src/test/resources/config/test.yml +++ b/service/src/test/resources/config/test.yml @@ -525,12 +525,6 @@ turn: linkDevice: secret: secret://linkDevice.secret -noiseTunnel: - webSocketPort: 8444 - directPort: 8445 - noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey - recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret - externalRequestFilter: grpcMethods: - com.example.grpc.ExampleService/exampleMethod @@ -541,3 +535,6 @@ externalRequestFilter: idlePrimaryDeviceReminder: minIdleDuration: P30D + +grpc: + port: 50051