Add direct grpc server

This commit is contained in:
ravi-signal
2025-10-06 15:22:36 -05:00
committed by GitHub
parent 9751569dc7
commit a2f2fc93b0
87 changed files with 546 additions and 6469 deletions

View File

@@ -39,13 +39,13 @@ import org.whispersystems.textsecuregcm.configuration.FcmConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration;
import org.whispersystems.textsecuregcm.configuration.GenericZkConfig;
import org.whispersystems.textsecuregcm.configuration.GooglePlayBillingConfiguration;
import org.whispersystems.textsecuregcm.configuration.GrpcConfiguration;
import org.whispersystems.textsecuregcm.configuration.IdlePrimaryDeviceReminderConfiguration;
import org.whispersystems.textsecuregcm.configuration.KeyTransparencyServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.LinkDeviceSecretConfiguration;
import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.NoiseTunnelConfiguration;
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.OpenTelemetryConfiguration;
import org.whispersystems.textsecuregcm.configuration.PagedSingleUseKEMPreKeyStoreConfiguration;
@@ -315,11 +315,6 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration();
@Valid
@NotNull
@JsonProperty
private NoiseTunnelConfiguration noiseTunnel;
@Valid
@NotNull
@JsonProperty
@@ -348,6 +343,11 @@ public class WhisperServerConfiguration extends Configuration {
@NotNull
private RetryConfiguration generalRedisRetry = new RetryConfiguration();
@NotNull
@Valid
@JsonProperty
private GrpcConfiguration grpc;
public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() {
return tlsKeyStore;
}
@@ -551,10 +551,6 @@ public class WhisperServerConfiguration extends Configuration {
return virtualThread;
}
public NoiseTunnelConfiguration getNoiseTunnelConfiguration() {
return noiseTunnel;
}
public ExternalRequestFilterConfiguration getExternalRequestFilterConfiguration() {
return externalRequestFilter;
}
@@ -582,4 +578,8 @@ public class WhisperServerConfiguration extends Configuration {
public RetryConfiguration getGeneralRedisRetryConfiguration() {
return generalRedisRetry;
}
public GrpcConfiguration getGrpc() {
return grpc;
}
}

View File

@@ -23,12 +23,14 @@ import io.dropwizard.core.setup.Environment;
import io.dropwizard.jetty.HttpsConnectorFactory;
import io.dropwizard.lifecycle.setup.LifecycleEnvironment;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition;
import io.grpc.netty.NettyServerBuilder;
import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder;
import io.lettuce.core.metrics.MicrometerOptions;
import io.lettuce.core.resource.ClientResources;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.resolver.ResolvedAddressTypes;
@@ -38,16 +40,12 @@ import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletRegistration;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.net.InetSocketAddress;
import java.net.http.HttpClient;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
@@ -62,7 +60,6 @@ import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.function.Function;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.core.WebSocketComponents;
import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
@@ -154,12 +151,8 @@ import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.ValidatingInterceptor;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
import org.whispersystems.textsecuregcm.grpc.net.ManagedGrpcServer;
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectTunnelServer;
import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer;
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
@@ -259,9 +252,9 @@ import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager;
import org.whispersystems.textsecuregcm.subscriptions.GooglePlayBillingManager;
import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
import org.whispersystems.textsecuregcm.util.ResilienceUtil;
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import org.whispersystems.textsecuregcm.util.ManagedExecutors;
import org.whispersystems.textsecuregcm.util.ResilienceUtil;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
@@ -636,9 +629,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
() -> dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion());
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
storageServiceExecutor, retryExecutor, config.getSecureStorageServiceConfiguration());
final GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager();
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient,
grpcClientConnectionManager, disconnectionRequestListenerExecutor, retryExecutor);
disconnectionRequestListenerExecutor, retryExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, retryExecutor, asyncCdnS3Client,
config.getCdnConfiguration().bucket());
MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler,
@@ -828,127 +820,63 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getAppleDeviceCheck().teamId(),
config.getAppleDeviceCheck().bundleId());
final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
final MetricServerInterceptor metricServerInterceptor = new MetricServerInterceptor(Metrics.globalRegistry, clientReleaseManager);
final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor();
final RequestAttributesInterceptor requestAttributesInterceptor =
new RequestAttributesInterceptor(grpcClientConnectionManager);
final RequestAttributesInterceptor requestAttributesInterceptor = new RequestAttributesInterceptor();
final ValidatingInterceptor validatingInterceptor = new ValidatingInterceptor();
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous");
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated");
final ExternalRequestFilter grpcExternalRequestFilter = new ExternalRequestFilter(
config.getExternalRequestFilterConfiguration().permittedInternalRanges(),
config.getExternalRequestFilterConfiguration().grpcMethods());
final RequireAuthenticationInterceptor requireAuthenticationInterceptor = new RequireAuthenticationInterceptor(accountAuthenticator);
final ProhibitAuthenticationInterceptor prohibitAuthenticationInterceptor = new ProhibitAuthenticationInterceptor();
final ManagedLocalGrpcServer anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, localEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
// depends on the user-agent context so it has to come first here!
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
serverBuilder
.intercept(
new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(),
config.getExternalRequestFilterConfiguration().grpcMethods()))
.intercept(validatingInterceptor)
.intercept(metricServerInterceptor)
.intercept(errorMappingInterceptor)
.intercept(remoteDeprecationFilter)
.intercept(requestAttributesInterceptor)
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager))
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager, zkSecretParams, Clock.systemUTC()))
.addService(new PaymentsGrpcService(currencyManager))
.addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
.addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkSecretParams));
}
};
final List<ServerServiceDefinition> authenticatedServices = Stream.of(
new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager),
ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters),
new KeysGrpcService(accountsManager, keysManager, rateLimiters),
new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
config.getBadges(), profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations))
.map(bindableService -> ServerInterceptors.intercept(bindableService,
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
// depends on the user-agent context so it has to come first here!
validatingInterceptor,
metricServerInterceptor,
errorMappingInterceptor,
remoteDeprecationFilter,
requestAttributesInterceptor,
requireAuthenticationInterceptor))
.toList();
final ManagedLocalGrpcServer authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, localEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
// depends on the user-agent context so it has to come first here!
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
serverBuilder
.intercept(validatingInterceptor)
.intercept(metricServerInterceptor)
.intercept(errorMappingInterceptor)
.intercept(remoteDeprecationFilter)
.intercept(requestAttributesInterceptor)
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager))
.addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager))
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
.addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters))
.addService(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
config.getBadges(), profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations));
}
};
final List<ServerServiceDefinition> unauthenticatedServices = Stream.of(
new AccountsAnonymousGrpcService(accountsManager, rateLimiters),
new KeysAnonymousGrpcService(accountsManager, keysManager, zkSecretParams, Clock.systemUTC()),
new PaymentsGrpcService(currencyManager),
ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config),
new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkSecretParams))
.map(bindableService -> ServerInterceptors.intercept(bindableService,
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
// depends on the user-agent context so it has to come first here!
grpcExternalRequestFilter,
validatingInterceptor,
metricServerInterceptor,
errorMappingInterceptor,
remoteDeprecationFilter,
requestAttributesInterceptor,
prohibitAuthenticationInterceptor))
.toList();
@Nullable final X509Certificate[] noiseWebSocketTlsCertificateChain;
@Nullable final PrivateKey noiseWebSocketTlsPrivateKey;
final ServerBuilder<?> serverBuilder =
NettyServerBuilder.forAddress(new InetSocketAddress(config.getGrpc().bindAddress(), config.getGrpc().port()));
authenticatedServices.forEach(serverBuilder::addService);
unauthenticatedServices.forEach(serverBuilder::addService);
final ManagedGrpcServer exposedGrpcServer = new ManagedGrpcServer(serverBuilder.build());
if (config.getNoiseTunnelConfiguration().tlsKeyStoreFile() != null &&
config.getNoiseTunnelConfiguration().tlsKeyStoreEntryAlias() != null &&
config.getNoiseTunnelConfiguration().tlsKeyStorePassword() != null) {
try (final FileInputStream websocketNoiseTunnelTlsKeyStoreInputStream = new FileInputStream(config.getNoiseTunnelConfiguration().tlsKeyStoreFile())) {
final KeyStore keyStore = KeyStore.getInstance("PKCS12");
keyStore.load(websocketNoiseTunnelTlsKeyStoreInputStream, config.getNoiseTunnelConfiguration().tlsKeyStorePassword().value().toCharArray());
final KeyStore.PrivateKeyEntry privateKeyEntry = (KeyStore.PrivateKeyEntry) keyStore.getEntry(
config.getNoiseTunnelConfiguration().tlsKeyStoreEntryAlias(),
new KeyStore.PasswordProtection(config.getNoiseTunnelConfiguration().tlsKeyStorePassword().value().toCharArray()));
noiseWebSocketTlsCertificateChain =
Arrays.copyOf(privateKeyEntry.getCertificateChain(), privateKeyEntry.getCertificateChain().length, X509Certificate[].class);
noiseWebSocketTlsPrivateKey = privateKeyEntry.getPrivateKey();
}
} else {
noiseWebSocketTlsCertificateChain = null;
noiseWebSocketTlsPrivateKey = null;
}
final ExecutorService noiseWebSocketDelegatedTaskExecutor = ExecutorServiceBuilder.of(environment, "noiseWebsocketDelegatedTask")
.minThreads(8)
.maxThreads(8)
.allowCoreThreadTimeOut(false)
.build();
final ManagedNioEventLoopGroup noiseTunnelEventLoopGroup = new ManagedNioEventLoopGroup();
final NoiseWebSocketTunnelServer noiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(
config.getNoiseTunnelConfiguration().webSocketPort(),
noiseWebSocketTlsCertificateChain,
noiseWebSocketTlsPrivateKey,
noiseTunnelEventLoopGroup,
noiseWebSocketDelegatedTaskExecutor,
grpcClientConnectionManager,
clientPublicKeysManager,
config.getNoiseTunnelConfiguration().noiseStaticKeyPair(),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress,
config.getNoiseTunnelConfiguration().recognizedProxySecret().value());
final NoiseDirectTunnelServer noiseDirectTunnelServer = new NoiseDirectTunnelServer(
config.getNoiseTunnelConfiguration().directPort(),
noiseTunnelEventLoopGroup,
grpcClientConnectionManager,
clientPublicKeysManager,
config.getNoiseTunnelConfiguration().noiseStaticKeyPair(),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress);
environment.lifecycle().manage(localEventLoopGroup);
environment.lifecycle().manage(dnsResolutionEventLoopGroup);
environment.lifecycle().manage(anonymousGrpcServer);
environment.lifecycle().manage(authenticatedGrpcServer);
environment.lifecycle().manage(noiseTunnelEventLoopGroup);
environment.lifecycle().manage(noiseWebSocketTunnelServer);
environment.lifecycle().manage(noiseDirectTunnelServer);
environment.lifecycle().manage(exposedGrpcServer);
final List<Filter> filters = new ArrayList<>();
filters.add(remoteDeprecationFilter);

View File

@@ -27,8 +27,6 @@ import java.util.concurrent.ScheduledExecutorService;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
@@ -47,7 +45,6 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil;
public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte[]> implements Managed {
private final FaultTolerantRedisClient pubSubClient;
private final GrpcClientConnectionManager grpcClientConnectionManager;
private final Executor listenerEventExecutor;
private final ScheduledExecutorService retryExecutor;
@@ -74,12 +71,10 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte
private record AccountIdentifierAndDeviceId(UUID accountIdentifier, byte deviceId) {}
public DisconnectionRequestManager(final FaultTolerantRedisClient pubSubClient,
final GrpcClientConnectionManager grpcClientConnectionManager,
final Executor listenerEventExecutor,
final ScheduledExecutorService retryExecutor) {
this.pubSubClient = pubSubClient;
this.grpcClientConnectionManager = grpcClientConnectionManager;
this.listenerEventExecutor = listenerEventExecutor;
this.retryExecutor = retryExecutor;
}
@@ -223,8 +218,6 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte
}
deviceIds.forEach(deviceId -> {
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(accountIdentifier, deviceId));
listeners.getOrDefault(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), Collections.emptyList())
.forEach(listener -> listenerEventExecutor.execute(() -> {
try {

View File

@@ -1,22 +0,0 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.ServerCall;
import io.grpc.ServerInterceptor;
import java.util.Optional;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
throws ChannelNotFoundException {
return grpcClientConnectionManager.getAuthenticatedDevice(call);
}
}

View File

@@ -1,40 +1,30 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
* reject the call with a status of {@code UNAVAILABLE}.
* contain an authorization header in the request metdata. Calls with an associated authenticated device are closed with
* an {@code UNAUTHENTICATED} status.
*/
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
public ProhibitAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
super(grpcClientConnectionManager);
}
public class ProhibitAuthenticationInterceptor implements ServerInterceptor {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
try {
return getAuthenticatedDevice(call)
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
// service via an authenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
.orElseGet(() -> next.startCall(call, headers));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
final Metadata headers, final ServerCallHandler<ReqT, RespT> next) {
final String authHeaderString = headers.get(Metadata.Key.of(RequireAuthenticationInterceptor.AUTHORIZATION_HEADER, Metadata.ASCII_STRING_MARSHALLER));
if (authHeaderString != null) {
call.close(Status.UNAUTHENTICATED.withDescription("authorization header forbidden"), new Metadata());
return new ServerCall.Listener<>() {};
}
return next.startCall(call, headers);
}
}

View File

@@ -1,43 +1,68 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth.grpc;
import io.dropwizard.auth.basic.BasicCredentials;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import java.util.Optional;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
/**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
* A "require authentication" interceptor authenticates requests and attaches the {@link AuthenticatedDevice} to the
* current gRPC context. Calls without authentication or with invalid credentials are closed with an
* {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined (i.e. because the accounts
* database is unavailable), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
*/
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
public class RequireAuthenticationInterceptor implements ServerInterceptor {
public RequireAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
super(grpcClientConnectionManager);
static final String AUTHORIZATION_HEADER = "authorization";
private final AccountAuthenticator authenticator;
public RequireAuthenticationInterceptor(final AccountAuthenticator authenticator) {
this.authenticator = authenticator;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
final Metadata headers, final ServerCallHandler<ReqT, RespT> next) {
final String authHeaderString = headers.get(
Metadata.Key.of(AUTHORIZATION_HEADER, Metadata.ASCII_STRING_MARSHALLER));
try {
return getAuthenticatedDevice(call)
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next))
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
// problem with the client's request.
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
} catch (final ChannelNotFoundException e) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
if (authHeaderString == null) {
return ServerInterceptorUtil.closeWithStatus(call,
Status.UNAUTHENTICATED.withDescription("missing authorization header"));
}
final Optional<BasicCredentials> basicCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeaderString);
if (basicCredentials.isEmpty()) {
return ServerInterceptorUtil.closeWithStatus(call,
Status.UNAUTHENTICATED.withDescription("malformed authorization header"));
}
final Optional<org.whispersystems.textsecuregcm.auth.AuthenticatedDevice> authenticated =
authenticator.authenticate(basicCredentials.get());
if (authenticated.isEmpty()) {
return ServerInterceptorUtil.closeWithStatus(call,
Status.UNAUTHENTICATED.withDescription("invalid credentials"));
}
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(
authenticated.get().accountIdentifier(),
authenticated.get().deviceId());
return Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next);
}
}

View File

@@ -0,0 +1,15 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.constraints.NotNull;
public record GrpcConfiguration(@NotNull String bindAddress, @NotNull Integer port) {
public GrpcConfiguration {
if (bindAddress == null || bindAddress.isEmpty()) {
bindAddress = "localhost";
}
}
}

View File

@@ -1,25 +0,0 @@
package org.whispersystems.textsecuregcm.configuration;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Positive;
import javax.annotation.Nullable;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPrivateKey;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
public record NoiseTunnelConfiguration(@Positive int webSocketPort,
@Positive int directPort,
@Nullable String tlsKeyStoreFile,
@Nullable String tlsKeyStoreEntryAlias,
@Nullable SecretString tlsKeyStorePassword,
@NotNull SecretBytes noiseStaticPrivateKey,
@NotNull SecretString recognizedProxySecret) {
public ECKeyPair noiseStaticKeyPair() throws InvalidKeyException {
final ECPrivateKey privateKey = new ECPrivateKey(noiseStaticPrivateKey().value());
return new ECKeyPair(privateKey.publicKey(), privateKey);
}
}

View File

@@ -1,55 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
/**
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
*/
public class ChannelShutdownInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
// Don't allow new calls if the connection is getting ready to close
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
@Override
public void onComplete() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onComplete();
}
@Override
public void onCancel() {
grpcClientConnectionManager.handleServerCallComplete(call);
super.onCancel();
}
};
}
}

View File

@@ -1,41 +1,108 @@
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.net.HttpHeaders;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
* The request attributes interceptor makes common request attributes from call metadata available to service
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
*
* @see RequestAttributesUtil
*/
public class RequestAttributesInterceptor implements ServerInterceptor {
private final GrpcClientConnectionManager grpcClientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
}
private static final Metadata.Key<String> ACCEPT_LANG_KEY =
Metadata.Key.of(HttpHeaders.ACCEPT_LANGUAGE, Metadata.ASCII_STRING_MARSHALLER);
private static final Metadata.Key<String> USER_AGENT_KEY =
Metadata.Key.of(HttpHeaders.USER_AGENT, Metadata.ASCII_STRING_MARSHALLER);
private static final Metadata.Key<String> X_FORWARDED_FOR_KEY =
Metadata.Key.of(HttpHeaders.X_FORWARDED_FOR, Metadata.ASCII_STRING_MARSHALLER);
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
final String userAgentHeader = headers.get(USER_AGENT_KEY);
final String acceptLanguageHeader = headers.get(ACCEPT_LANG_KEY);
final String xForwardedForHeader = headers.get(X_FORWARDED_FOR_KEY);
try {
return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
} catch (final ChannelNotFoundException e) {
final Optional<InetAddress> remoteAddress = getMostRecentProxy(xForwardedForHeader)
.flatMap(mostRecentProxy -> {
try {
return Optional.of(InetAddress.ofLiteral(mostRecentProxy));
} catch (IllegalArgumentException e) {
log.warn("Failed to parse most recent proxy {} as an IP address", mostRecentProxy, e);
return Optional.empty();
}
})
.or(() -> {
log.warn("No usable X-Forwarded-For header present, using remote socket address");
final SocketAddress socketAddress = call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
if (socketAddress == null || !(socketAddress instanceof InetSocketAddress inetAddress)) {
log.warn("Remote socket address not present or is not an inet address: {}", socketAddress);
return Optional.empty();
}
return Optional.of(inetAddress.getAddress());
});
if (!remoteAddress.isPresent()) {
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
}
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try {
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
} catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
}
}
final RequestAttributes requestAttributes =
new RequestAttributes(remoteAddress.get(), userAgentHeader, acceptLanguages);
return Contexts.interceptCall(
Context.current().withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes),
call, headers, next);
}
/**
* Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header.
*
* @param forwardedFor the value of an X-Forwarded-For header
* @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or
* {@code forwardedFor} was null
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
* MDN</a>
*/
public static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
return Optional.ofNullable(forwardedFor)
.map(ff -> {
final int idx = forwardedFor.lastIndexOf(',') + 1;
return idx < forwardedFor.length()
? forwardedFor.substring(idx).trim()
: null;
})
.filter(StringUtils::isNotBlank);
}
}

View File

@@ -1,41 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import javax.crypto.BadPaddingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
/**
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions
* thrown in inbound handlers into {@link OutboundCloseErrorMessage}s.
*/
public class ErrorHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.NOISE_ERROR,
"Noise encryption error");
@Override
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) {
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
e.getMessage());
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
default -> {
log.warn("An unexpected exception reached the end of the pipeline", cause);
yield new OutboundCloseErrorMessage(
OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR,
cause.getMessage());
}
};
context.writeAndFlush(closeMessage)
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
}

View File

@@ -1,128 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.util.ReferenceCountUtil;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
/**
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
* server.
*/
public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class);
private final GrpcClientConnectionManager grpcClientConnectionManager;
private final LocalAddress authenticatedGrpcServerAddress;
private final LocalAddress anonymousGrpcServerAddress;
private final FramingType framingType;
private final List<Object> pendingReads = new ArrayList<>();
private static final String CONNECTION_ESTABLISHED_COUNTER_NAME = MetricsUtil.name(EstablishLocalGrpcConnectionHandler.class, "established");
public EstablishLocalGrpcConnectionHandler(final GrpcClientConnectionManager grpcClientConnectionManager,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final FramingType framingType) {
this.grpcClientConnectionManager = grpcClientConnectionManager;
this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress;
this.anonymousGrpcServerAddress = anonymousGrpcServerAddress;
this.framingType = framingType;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
pendingReads.add(message);
}
@Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseIdentityDeterminedEvent(
final Optional<AuthenticatedDevice> authenticatedDevice,
InetAddress remoteAddress, String userAgent, String acceptLanguage)) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
// authenticated service.
final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
? authenticatedGrpcServerAddress
: anonymousGrpcServerAddress;
GrpcClientConnectionManager.handleHandshakeInitiated(
remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage);
final List<Tag> tags = UserAgentTagUtil.getLibsignalAndPlatformTags(userAgent);
Metrics.counter(CONNECTION_ESTABLISHED_COUNTER_NAME, Tags.of(tags)
.and("authenticated", Boolean.toString(authenticatedDevice.isPresent()))
.and("framingType", framingType.name()))
.increment();
new Bootstrap()
.remoteAddress(grpcServerAddress)
.channel(LocalChannel.class)
.group(remoteChannelContext.channel().eventLoop())
.handler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(final LocalChannel localChannel) {
localChannel.pipeline().addLast(new ProxyHandler(remoteChannelContext.channel()));
}
})
.connect()
.addListener((ChannelFutureListener) localChannelFuture -> {
if (localChannelFuture.isSuccess()) {
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(),
authenticatedDevice);
// Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
remoteChannelContext.channel()
.write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
remoteChannelContext.pipeline()
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
// Flush any buffered reads we accumulated while waiting to open the connection
pendingReads.forEach(remoteChannelContext::fireChannelRead);
pendingReads.clear();
remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
} else {
log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause());
remoteChannelContext.close();
}
});
}
remoteChannelContext.fireUserEventTriggered(event);
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
pendingReads.forEach(ReferenceCountUtil::release);
pendingReads.clear();
}
}

View File

@@ -1,6 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
public enum FramingType {
NOISE_DIRECT,
WEBSOCKET
}

View File

@@ -1,268 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Grpc;
import io.grpc.ServerCall;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.util.AttributeKey;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.ClosableEpoch;
/**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
* identity of the device that opened the connection (for non-anonymous connections). It can also close connections
* associated with a given device if that device's credentials have changed and clients must reauthenticate.
* <p>
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
* narrow window between a server call being created and the start of call execution, in which case accessor methods
* in this class will throw a {@link ChannelNotFoundException}.
* <p>
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
* be called from any application code.
*/
public class GrpcClientConnectionManager {
private final Map<LocalAddress, Channel> remoteChannelsByLocalAddress = new ConcurrentHashMap<>();
private final Map<AuthenticatedDevice, List<Channel>> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>();
@VisibleForTesting
static final AttributeKey<AuthenticatedDevice> AUTHENTICATED_DEVICE_ATTRIBUTE_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
@VisibleForTesting
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
private static final OutboundCloseErrorMessage SERVER_CLOSED =
new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
/**
* Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
* (i.e. unauthenticated), the returned value will be empty.
*
* @param serverCall the gRPC server call for which to find an authenticated device
*
* @return the authenticated device associated with the given local address, if any
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
throws ChannelNotFoundException {
return getAuthenticatedDevice(getRemoteChannel(serverCall));
}
@VisibleForTesting
Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
}
/**
* Returns the request attributes associated with the given server call.
*
* @param serverCall the gRPC server call for which to retrieve request attributes
*
* @return the request attributes associated with the given server call
*
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
* generally indicates that the channel has closed while request processing is still in progress
*/
public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRequestAttributes(getRemoteChannel(serverCall));
}
@VisibleForTesting
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
if (requestAttributes == null) {
throw new IllegalStateException("Channel does not have request attributes");
}
return requestAttributes;
}
/**
* Handles the start of a server call, incrementing the active call count for the remote channel associated with the
* given server call.
*
* @param serverCall the server call to start
*
* @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
* underlying channel is closing
*/
public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
try {
return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
} catch (final ChannelNotFoundException e) {
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
// should certainly not proceed.
return false;
}
}
/**
* Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
* associated with the given server call.
*
* @param serverCall the server call to complete
*/
public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
try {
getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
} catch (final ChannelNotFoundException ignored) {
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
}
}
/**
* Closes any client connections to this host associated with the given authenticated device.
*
* @param authenticatedDevice the authenticated device for which to close connections
*/
public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
// Channels will actually get removed from the list/map by their closeFuture listeners. We copy the list to avoid
// concurrent modification; it's possible (though practically unlikely) that a channel can close and remove itself
// from the list while we're still iterating, resulting in a `ConcurrentModificationException`.
final List<Channel> channelsToClose =
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
}
private static void closeRemoteChannel(final Channel channel) {
channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
@VisibleForTesting
@Nullable List<Channel> getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) {
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
}
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
return getRemoteChannel(getLocalAddress(serverCall));
}
@VisibleForTesting
Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
if (remoteChannel == null) {
throw new ChannelNotFoundException();
}
return remoteChannelsByLocalAddress.get(localAddress);
}
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
// a proxied Noise channel.
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
return localAddress;
}
/**
* Handles receipt of a handshake message and associates attributes and headers from the handshake
* request with the channel via which the handshake took place.
*
* @param channel the channel where the handshake was initiated
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null}
*/
public static void handleHandshakeInitiated(final Channel channel,
final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) {
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try {
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
} catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
}
}
channel.attr(REQUEST_ATTRIBUTES_KEY)
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
}
/**
* Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
*
* @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
* @param remoteChannel the channel from the remote client to the Noise tunnel
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
*/
void handleConnectionEstablished(final LocalChannel localChannel,
final Channel remoteChannel,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
final List<Channel> channels = existingChannelList != null ? existingChannelList : new ArrayList<>();
channels.add(remoteChannel);
return channels;
}));
remoteChannel.closeFuture().addListener(closeFuture -> {
remoteChannelsByLocalAddress.remove(localChannel.localAddress());
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
if (existingChannelList == null) {
return null;
}
existingChannelList.remove(remoteChannel);
return existingChannelList.isEmpty() ? null : existingChannelList;
}));
});
}
}

View File

@@ -1,32 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* An HAProxy message handler handles decoded HAProxyMessage instances, removing itself from the pipeline once it has
* either handled a proxy protocol message or determined that no such message is coming.
*/
public class HAProxyMessageHandler extends ChannelInboundHandlerAdapter {
private static final Logger log = LoggerFactory.getLogger(HAProxyMessageHandler.class);
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof HAProxyMessage haProxyMessage) {
// Some network/deployment configurations will send us a proxy protocol message, but we don't use it. We still
// need to clear it from the pipeline to avoid confusing the TLS machinery, though.
log.debug("Discarding HAProxy message: {}", haProxyMessage);
haProxyMessage.release();
} else {
super.channelRead(context, message);
}
// Regardless of the type of the first message, we'll only ever receive zero or one HAProxyMessages. After the first
// message, all others will just be "normal" messages, and our work here is done.
context.pipeline().remove(this);
}
}

View File

@@ -1,21 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
public enum HandshakePattern {
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
private final String protocol;
public String protocol() {
return protocol;
}
HandshakePattern(String protocol) {
this.protocol = protocol;
}
}

View File

@@ -1,16 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.dropwizard.lifecycle.Managed;
import io.netty.channel.DefaultEventLoopGroup;
/**
* A wrapper for a Netty {@link DefaultEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
* Dropwizard to manage the lifecycle of the event loop group.
*/
public class ManagedDefaultEventLoopGroup extends DefaultEventLoopGroup implements Managed {
@Override
public void stop() throws InterruptedException {
this.shutdownGracefully().await();
}
}

View File

@@ -0,0 +1,29 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.dropwizard.lifecycle.Managed;
import io.grpc.Server;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
public class ManagedGrpcServer implements Managed {
private final Server server;
public ManagedGrpcServer(Server server) {
this.server = server;
}
@Override
public void start() throws IOException {
server.start();
}
@Override
public void stop() {
try {
server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
} catch (final InterruptedException e) {
server.shutdownNow();
}
}
}

View File

@@ -1,49 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.dropwizard.lifecycle.Managed;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
/**
* A managed, local gRPC server configures and wraps a gRPC {@link Server} that listens on a Netty {@link LocalAddress}
* and whose lifecycle is managed by Dropwizard via the {@link Managed} interface.
*/
public abstract class ManagedLocalGrpcServer implements Managed {
private final Server server;
public ManagedLocalGrpcServer(final LocalAddress localAddress,
final DefaultEventLoopGroup eventLoopGroup) {
final ServerBuilder<?> serverBuilder = NettyServerBuilder.forAddress(localAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup);
configureServer(serverBuilder);
server = serverBuilder.build();
}
protected abstract void configureServer(final ServerBuilder<?> serverBuilder);
@Override
public void start() throws IOException {
server.start();
}
@Override
public void stop() {
try {
server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
} catch (final InterruptedException e) {
server.shutdownNow();
}
}
}

View File

@@ -1,13 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
/**
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
* format or a general encryption error).
*/
public class NoiseException extends NoStackTraceException {
public NoiseException(final String message) {
super(message);
}
}

View File

@@ -1,119 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.CipherState;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.PromiseCombiner;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound
* messages
*/
public class NoiseHandler extends ChannelDuplexHandler {
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
private final CipherStatePair cipherStatePair;
NoiseHandler(CipherStatePair cipherStatePair) {
this.cipherStatePair = cipherStatePair;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (message instanceof ByteBuf frame) {
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
throw new NoiseException("Invalid noise message length " + frame.readableBytes());
}
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer.
handleInboundDataMessage(context, ByteBufUtil.getBytes(frame));
} else {
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
} finally {
ReferenceCountUtil.release(message);
}
}
private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes)
throws ShortBufferException, BadPaddingException {
final CipherState cipherState = cipherStatePair.getReceiver();
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
final int plaintextLength = cipherState.decryptWithAd(null,
frameBytes, 0,
frameBytes, 0,
frameBytes.length);
// Forward the decrypted plaintext along
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
}
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
if (message instanceof ByteBuf byteBuf) {
try {
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
final CipherState cipherState = cipherStatePair.getSender();
// Server message might not fit in a single noise packet, break it up into as many chunks as we need
final PromiseCombiner pc = new PromiseCombiner(context.executor());
while (byteBuf.isReadable()) {
final ByteBuf plaintext = byteBuf.readSlice(Math.min(
// need room for a 16-byte AEAD tag
Noise.MAX_PACKET_LEN - 16,
byteBuf.readableBytes()));
final int plaintextLength = plaintext.readableBytes();
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
}
pc.finish(promise);
} finally {
ReferenceCountUtil.release(byteBuf);
}
} else {
if (!(message instanceof OutboundCloseErrorMessage)) {
// Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
// that get issued in response to exceptions)
log.warn("Unexpected object in pipeline: {}", message);
}
context.write(message, promise);
}
}
@Override
public void handlerRemoved(ChannelHandlerContext var1) {
if (cipherStatePair != null) {
cipherStatePair.destroy();
}
}
}

View File

@@ -1,14 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
/**
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
* a general encryption error).
*/
public class NoiseHandshakeException extends NoStackTraceException {
public NoiseHandshakeException(final String message) {
super(message);
}
}

View File

@@ -1,197 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import java.io.IOException;
import java.net.InetAddress;
import java.security.MessageDigest;
import java.util.Optional;
import java.util.UUID;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
/**
* Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will
* encrypt/decrypt subsequent data frames
* <p>
* The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator
* handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler
* currently supports two types of handshakes.
* <p>
* The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator
* handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in
* the handshake message must match the server's stored public key for the device identified by the provided ACI and
* deviceId.
* <p>
* The second are NK handshakes which are anonymous.
* <p>
* Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server
* to begin processing the request without an initial message delay (fast open).
* <p>
* Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake,
* this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This
* handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data
* frames.
*/
public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY)
.build().toByteArray();
private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
.build().toByteArray();
// We might get additional messages while we're waiting to process a handshake, so keep track of where we are
private boolean receivedHandshakeInit = false;
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
try {
if (!(message instanceof NoiseHandshakeInit handshakeInit)) {
// Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
if (receivedHandshakeInit) {
throw new NoiseHandshakeException("Should not receive messages until handshake complete");
}
receivedHandshakeInit = true;
if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) {
throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes());
}
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
// We'll need to copy it to a heap buffer
handleInboundHandshake(context,
handshakeInit.getRemoteAddress(),
handshakeInit.getHandshakePattern(),
ByteBufUtil.getBytes(handshakeInit.content()));
} finally {
ReferenceCountUtil.release(message);
}
}
private void handleInboundHandshake(
final ChannelHandlerContext context,
final InetAddress remoteAddress,
final HandshakePattern handshakePattern,
final byte[] frameBytes) throws NoiseHandshakeException {
final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair);
final ByteBuf payload = handshakeHelper.read(frameBytes);
// Parse the handshake message
final NoiseTunnelProtos.HandshakeInit handshakeInit;
try {
handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload));
} catch (IOException e) {
throw new NoiseHandshakeException("Failed to parse handshake message");
}
switch (handshakePattern) {
case NK -> {
if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) {
throw new NoiseHandshakeException("Anonymous handshake should not include identifiers");
}
handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty());
}
case IK -> {
final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey()
.orElseThrow(() -> new IllegalStateException("No remote public key"));
final UUID accountIdentifier = aci(handshakeInit);
final byte deviceId = deviceId(handshakeInit);
clientPublicKeysManager
.findPublicKey(accountIdentifier, deviceId)
.whenCompleteAsync((storedPublicKey, throwable) -> {
if (throwable != null) {
context.fireExceptionCaught(ExceptionUtils.unwrap(throwable));
return;
}
final boolean valid = storedPublicKey
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
.orElse(false);
if (!valid) {
// Write a handshake response indicating that the client used the wrong public key
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK);
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
context.fireExceptionCaught(new NoiseHandshakeException("Bad public key"));
return;
}
handleAuthenticated(context,
handshakeHelper, remoteAddress, handshakeInit,
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
}, context.executor());
}
};
}
private void handleAuthenticated(final ChannelHandlerContext context,
final NoiseHandshakeHelper handshakeHelper,
final InetAddress remoteAddress,
final NoiseTunnelProtos.HandshakeInit handshakeInit,
final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(
maybeAuthenticatedDevice,
remoteAddress,
handshakeInit.getUserAgent(),
handshakeInit.getAcceptLanguage()));
// Now that we've authenticated, write the handshake response
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK);
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
// Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open
// request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already
// decrypted the fastOpen request
context.pipeline()
.addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split()));
context.pipeline().remove(NoiseHandshakeHandler.class);
if (!handshakeInit.getFastOpenRequest().isEmpty()) {
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
// encrypt the response when the server writes back through us
context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer()));
}
}
private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
try {
return UUIDUtil.fromByteString(handshakePayload.getAci());
} catch (IllegalArgumentException e) {
throw new NoiseHandshakeException("Could not parse aci");
}
}
private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) {
throw new NoiseHandshakeException("Invalid deviceId");
}
return (byte) handshakePayload.getDeviceId();
}
}

View File

@@ -1,124 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
import com.southernstorm.noise.protocol.HandshakeState;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.security.NoSuchAlgorithmException;
import java.util.Optional;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
/**
* Helper for the responder of a 2-message handshake with a pre-shared responder static key
*/
class NoiseHandshakeHelper {
private final static int AEAD_TAG_LENGTH = 16;
private final static int KEY_LENGTH = 32;
private final HandshakePattern handshakePattern;
private final ECKeyPair serverStaticKeyPair;
private final HandshakeState handshakeState;
NoiseHandshakeHelper(HandshakePattern handshakePattern, ECKeyPair serverStaticKeyPair) {
this.handshakePattern = handshakePattern;
this.serverStaticKeyPair = serverStaticKeyPair;
try {
this.handshakeState = new HandshakeState(handshakePattern.protocol(), HandshakeState.RESPONDER);
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Unsupported Noise algorithm: " + handshakePattern.protocol(), e);
}
}
/**
* Get the length of the initiator's keys
*
* @return length of the handshake message sent by the remote party (the initiator) not including the payload
*/
private int initiatorHandshakeMessageKeyLength() {
return switch (handshakePattern) {
// ephemeral key, static key (encrypted), AEAD tag for static key
case IK -> KEY_LENGTH + KEY_LENGTH + AEAD_TAG_LENGTH;
// ephemeral key only
case NK -> KEY_LENGTH;
};
}
HandshakeState getHandshakeState() {
return this.handshakeState;
}
ByteBuf read(byte[] remoteHandshakeMessage) throws NoiseHandshakeException {
if (handshakeState.getAction() != HandshakeState.NO_ACTION) {
throw new NoiseHandshakeException("Cannot send more data before handshake is complete");
}
// Length for an empty payload
final int minMessageLength = initiatorHandshakeMessageKeyLength() + AEAD_TAG_LENGTH;
if (remoteHandshakeMessage.length < minMessageLength || remoteHandshakeMessage.length > Noise.MAX_PACKET_LEN) {
throw new NoiseHandshakeException("Unexpected ephemeral key message length");
}
final int payloadLength = remoteHandshakeMessage.length - initiatorHandshakeMessageKeyLength() - AEAD_TAG_LENGTH;
// Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
// making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
// from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
// key), so we just set the private key here.
handshakeState.getLocalKeyPair().setPrivateKey(serverStaticKeyPair.getPrivateKey().serialize(), 0);
handshakeState.start();
int payloadBytesRead;
try {
payloadBytesRead = handshakeState.readMessage(remoteHandshakeMessage, 0, remoteHandshakeMessage.length,
remoteHandshakeMessage, 0);
} catch (final ShortBufferException e) {
// This should never happen since we're checking the length of the frame up front
throw new NoiseHandshakeException("Unexpected client payload");
} catch (final BadPaddingException e) {
// We aren't using padding but may get this error if the AEAD tag does not match the encrypted client static key
// or payload
throw new NoiseHandshakeException("Invalid keys or payload");
}
if (payloadBytesRead != payloadLength) {
throw new NoiseHandshakeException(
"Unexpected payload length, required " + payloadLength + " but got " + payloadBytesRead);
}
return Unpooled.wrappedBuffer(remoteHandshakeMessage, 0, payloadBytesRead);
}
byte[] write(byte[] payload) {
if (handshakeState.getAction() != HandshakeState.WRITE_MESSAGE) {
throw new IllegalStateException("Cannot send data before handshake is complete");
}
// Currently only support handshake patterns where the server static key is known
// Send our ephemeral key and the response to the initiator with the encrypted payload
final byte[] response = new byte[KEY_LENGTH + payload.length + AEAD_TAG_LENGTH];
try {
int written = handshakeState.writeMessage(response, 0, payload, 0, payload.length);
if (written != response.length) {
throw new IllegalStateException("Unexpected handshake response length");
}
return response;
} catch (final ShortBufferException e) {
// This should never happen for messages of known length that we control
throw new IllegalStateException("Key material buffer was too short for message", e);
}
}
Optional<byte[]> remotePublicKey() {
return Optional.ofNullable(handshakeState.getRemotePublicKey()).map(dhstate -> {
final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
return publicKeyFromClient;
});
}
}

View File

@@ -1,33 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
import java.net.InetAddress;
/**
* A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata
* and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic.
*/
public class NoiseHandshakeInit extends DefaultByteBufHolder {
private final InetAddress remoteAddress;
private final HandshakePattern handshakePattern;
public NoiseHandshakeInit(
final InetAddress remoteAddress,
final HandshakePattern handshakePattern,
final ByteBuf initiatorHandshakeMessage) {
super(initiatorHandshakeMessage);
this.remoteAddress = remoteAddress;
this.handshakePattern = handshakePattern;
}
public InetAddress getRemoteAddress() {
return remoteAddress;
}
public HandshakePattern getHandshakePattern() {
return handshakePattern;
}
}

View File

@@ -1,22 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import java.net.InetAddress;
import java.util.Optional;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
/**
* An event that indicates that an identity of a noise handshake initiator has been determined. If the initiator is
* connecting anonymously, the identity is empty, otherwise it will be present and already authenticated.
*
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
* type that performs authentication
* @param remoteAddress the remote address of the connecting client
* @param userAgent the client supplied userAgent
* @param acceptLanguage the client supplied acceptLanguage
*/
public record NoiseIdentityDeterminedEvent(
Optional<AuthenticatedDevice> authenticatedDevice,
InetAddress remoteAddress,
String userAgent,
String acceptLanguage) {
}

View File

@@ -1,31 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net;
/**
* An error written to the outbound pipeline that indicates the connection should be closed
*/
public record OutboundCloseErrorMessage(Code code, String message) {
public enum Code {
/**
* The server decided to close the connection. This could be because the server is going away, or it could be
* because the credentials for the connected client have been updated.
*/
SERVER_CLOSED,
/**
* There was a noise decryption error after the noise session was established
*/
NOISE_ERROR,
/**
* There was an error establishing the noise handshake
*/
NOISE_HANDSHAKE_ERROR,
INTERNAL_SERVER_ERROR
}
}

View File

@@ -1,24 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
/**
* A proxy handler writes all data read from one channel to another peer channel.
*/
public class ProxyHandler extends ChannelInboundHandlerAdapter {
private final Channel peerChannel;
public ProxyHandler(final Channel peerChannel) {
this.peerChannel = peerChannel;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) {
peerChannel.writeAndFlush(message)
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
}
}

View File

@@ -1,73 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
/**
* A proxy protocol detection handler watches for HAProxy PROXY protocol messages at the beginning of a TCP connection.
* If a connection begins with a proxy message, this handler will add a {@link HAProxyMessageDecoder} to the pipeline.
* In all cases, once this handler has determined that a connection does or does not begin with a proxy protocol
* message, it will remove itself from the pipeline and pass any intercepted down the pipeline.
*
* @see <a href="https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt">The PROXY protocol</a>
*/
public class ProxyProtocolDetectionHandler extends ChannelInboundHandlerAdapter {
private CompositeByteBuf accumulator;
@VisibleForTesting
static final int PROXY_MESSAGE_DETECTION_BYTES = 12;
@Override
public void handlerAdded(final ChannelHandlerContext context) {
// We need at least 12 bytes to decide if a byte buffer contains a proxy protocol message. Assuming we only get
// non-empty buffers, that means we'll need at most 12 sub-buffers to have a complete message. In virtually every
// practical case, though, we'll be able to tell from the first packet.
accumulator = new CompositeByteBuf(context.alloc(), false, PROXY_MESSAGE_DETECTION_BYTES);
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof ByteBuf byteBuf) {
accumulator.addComponent(true, byteBuf);
switch (HAProxyMessageDecoder.detectProtocol(accumulator).state()) {
case NEEDS_MORE_DATA -> {
}
case INVALID -> {
// We have enough information to determine that this connection is NOT starting with a proxy protocol message,
// and we can just pass the accumulated bytes through
context.fireChannelRead(accumulator);
accumulator = null;
context.pipeline().remove(this);
}
case DETECTED -> {
// We have enough information to know that we're dealing with a proxy protocol message; add appropriate
// handlers and pass the accumulated bytes through
context.pipeline().addAfter(context.name(), null, new HAProxyMessageDecoder());
context.fireChannelRead(accumulator);
accumulator = null;
context.pipeline().remove(this);
}
}
} else {
super.channelRead(context, message);
}
}
@Override
public void handlerRemoved(final ChannelHandlerContext context) {
if (accumulator != null) {
accumulator.release();
accumulator = null;
}
}
}

View File

@@ -1,45 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
/**
* In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
* noise packet to the noise layer as a {@link ByteBuf} for decryption.
* <p>
* In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
* so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
* handshake response, and then the subsequent messages are all data frame payloads.
*/
public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame frame) {
if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
ReferenceCountUtil.release(msg);
throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
}
ctx.fireChannelRead(frame.content());
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof ByteBuf bb) {
ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
} else {
ctx.write(msg, promise);
}
}
}

View File

@@ -1,72 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
public class NoiseDirectFrame extends DefaultByteBufHolder {
static final byte VERSION = 0x00;
private final FrameType frameType;
public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
super(data);
this.frameType = frameType;
}
public FrameType frameType() {
return frameType;
}
public byte versionedFrameTypeByte() {
final byte frameBits = frameType().getFrameBits();
return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
}
public enum FrameType {
/**
* The payload is the initiator message for a Noise NK handshake. If established, the
* session will be unauthenticated.
*/
NK_HANDSHAKE((byte) 1),
/**
* The payload is the initiator message for a Noise IK handshake. If established, the
* session will be authenticated.
*/
IK_HANDSHAKE((byte) 2),
/**
* The payload is an encrypted noise packet.
*/
DATA((byte) 3),
/**
* A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being
* closed.
*/
CLOSE((byte) 4);
private final byte frameType;
FrameType(byte frameType) {
if (frameType != (0x0F & frameType)) {
throw new IllegalStateException("Frame type must fit in 4 bits");
}
this.frameType = frameType;
}
public byte getFrameBits() {
return frameType;
}
public boolean isHandshake() {
return switch (this) {
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
case DATA, CLOSE -> false;
};
}
}
}

View File

@@ -1,90 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
/**
* Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
* have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
*/
public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof ByteBuf byteBuf) {
try {
ctx.fireChannelRead(deserialize(byteBuf));
} catch (Exception e) {
ReferenceCountUtil.release(byteBuf);
throw e;
}
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
try {
// Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
// network, nothing should have to make another copy of this. If later another layer is added, it may be more
// efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
ctx.writeAndFlush(serialized, promise);
} finally {
ReferenceCountUtil.release(noiseDirectFrame);
}
} else {
ctx.write(msg, promise);
}
}
private ByteBuf serialize(
final ChannelHandlerContext ctx,
final NoiseDirectFrame noiseDirectFrame) {
if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
}
// 1 version/frametype byte, 2 length bytes, content
final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
byteBuf.writeBytes(noiseDirectFrame.content());
return byteBuf;
}
private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
final byte versionAndFrameByte = byteBuf.readByte();
final int version = (versionAndFrameByte & 0xF0) >> 4;
if (version != NoiseDirectFrame.VERSION) {
throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
}
final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
case 3 -> NoiseDirectFrame.FrameType.DATA;
case 4 -> NoiseDirectFrame.FrameType.CLOSE;
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
};
final int length = Short.toUnsignedInt(byteBuf.readShort());
if (length != byteBuf.readableBytes()) {
throw new IllegalArgumentException(
"Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
}
return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
}
}

View File

@@ -1,53 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import java.io.IOException;
import java.net.InetSocketAddress;
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
/**
* Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and
* forwards the handshake frame along as a {@link NoiseHandshakeInit} message
*/
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame frame) {
try {
if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) {
throw new IOException("Could not determine remote address");
}
// We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler
// needs into a NoiseHandshakeInit message and forward that along
final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(),
switch (frame.frameType()) {
case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type");
case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector");
case IK_HANDSHAKE -> HandshakePattern.IK;
case NK_HANDSHAKE -> HandshakePattern.NK;
}, frame.content());
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
// for network serialization. Note that we need to install the Data frame handler before firing the read,
// because we may receive an outbound message from the noiseHandler
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
ctx.fireChannelRead(handshakeMessage);
} catch (Exception e) {
ReferenceCountUtil.release(msg);
throw e;
}
} else {
ctx.fireChannelRead(msg);
}
}
}

View File

@@ -1,36 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.micrometer.core.instrument.Metrics;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/**
* Watches for inbound close frames and closes the connection in response
*/
public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter {
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "clientClose");
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
try {
final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason
.parseFrom(ByteBufUtil.getBytes(ndf.content()));
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment();
} finally {
ReferenceCountUtil.release(msg);
ctx.close();
}
} else {
ctx.fireChannelRead(msg);
}
}
}

View File

@@ -1,43 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import io.micrometer.core.instrument.Metrics;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/**
* Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
* written, the channel is closed
*/
class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "serverClose");
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof OutboundCloseErrorMessage err) {
final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) {
case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE;
case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR;
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR;
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR;
};
Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "reason", code.name()).increment();
final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder()
.setCode(code)
.setMessage(err.message())
.build();
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
proto.writeTo(new ByteBufOutputStream(byteBuf));
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf))
.addListener(ChannelFutureListener.CLOSE);
} else {
ctx.write(msg, promise);
}
}
}

View File

@@ -1,96 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
import com.google.common.annotations.VisibleForTesting;
import com.southernstorm.noise.protocol.Noise;
import io.dropwizard.lifecycle.Managed;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import java.net.InetSocketAddress;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
import org.whispersystems.textsecuregcm.grpc.net.FramingType;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler;
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
* binary framing protocol) and passes it through to a local gRPC server.
*/
public class NoiseDirectTunnelServer implements Managed {
private final ServerBootstrap bootstrap;
private ServerSocketChannel channel;
private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
public NoiseDirectTunnelServer(final int port,
final NioEventLoopGroup eventLoopGroup,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress) {
this.bootstrap = new ServerBootstrap()
.group(eventLoopGroup)
.channel(NioServerSocketChannel.class)
.localAddress(port)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel socketChannel) {
socketChannel.pipeline()
.addLast(new ProxyProtocolDetectionHandler())
.addLast(new HAProxyMessageHandler())
// frame byte followed by a 2-byte length field
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
// Parses NoiseDirectFrames from wire bytes and vice versa
.addLast(new NoiseDirectFrameCodec())
// Terminate the connection if the client sends us a close frame
.addLast(new NoiseDirectInboundCloseHandler())
// Turn generic OutboundCloseErrorMessages into noise direct error frames
.addLast(new NoiseDirectOutboundErrorHandler())
// Forwards the first payload supplemented with handshake metadata, and then replaces itself with a
// NoiseDirectDataFrameCodec to handle subsequent data frames
.addLast(new NoiseDirectHandshakeSelector())
// Performs the noise handshake and then replace itself with a NoiseHandler
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(
grpcClientConnectionManager,
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
FramingType.NOISE_DIRECT))
.addLast(new ErrorHandler());
}
});
}
@VisibleForTesting
public InetSocketAddress getLocalAddress() {
return channel.localAddress();
}
@Override
public void start() throws InterruptedException {
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
}
@Override
public void stop() throws InterruptedException {
if (channel != null) {
channel.close().await();
}
}
}

View File

@@ -1,18 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
enum ApplicationWebSocketCloseReason {
NOISE_HANDSHAKE_ERROR(4001),
NOISE_ENCRYPTION_ERROR(4002);
private final int statusCode;
ApplicationWebSocketCloseReason(final int statusCode) {
this.statusCode = statusCode;
}
public int getStatusCode() {
return statusCode;
}
}

View File

@@ -1,146 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import com.google.common.annotations.VisibleForTesting;
import com.southernstorm.noise.protocol.Noise;
import io.dropwizard.lifecycle.Managed;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.OpenSsl;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProtocols;
import io.netty.handler.ssl.SslProvider;
import java.net.InetSocketAddress;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.*;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A Noise-over-WebSocket tunnel server accepts traffic from the public internet (in the form of Noise packets framed by
* binary WebSocket frames) and passes it through to a local gRPC server.
*/
public class NoiseWebSocketTunnelServer implements Managed {
private final ServerBootstrap bootstrap;
private ServerSocketChannel channel;
static final String AUTHENTICATED_SERVICE_PATH = "/authenticated";
static final String ANONYMOUS_SERVICE_PATH = "/anonymous";
static final String HEALTH_CHECK_PATH = "/health-check";
private static final Logger log = LoggerFactory.getLogger(NoiseWebSocketTunnelServer.class);
public NoiseWebSocketTunnelServer(final int websocketPort,
@Nullable final X509Certificate[] tlsCertificateChain,
@Nullable final PrivateKey tlsPrivateKey,
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final GrpcClientConnectionManager grpcClientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws SSLException {
@Nullable final SslContext sslContext;
if (tlsCertificateChain != null && tlsPrivateKey != null) {
final SslProvider sslProvider;
if (OpenSsl.isAvailable()) {
log.info("Native OpenSSL provider is available; will use native provider");
sslProvider = SslProvider.OPENSSL;
} else {
log.info("No native SSL provider available; will use JDK provider");
sslProvider = SslProvider.JDK;
}
sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain)
.clientAuth(ClientAuth.NONE)
// Some load balancers require TLS 1.2 for health checks
.protocols(SslProtocols.TLS_v1_3, SslProtocols.TLS_v1_2)
.sslProvider(sslProvider)
.build();
} else {
log.warn("No TLS credentials provided; Noise-over-WebSocket tunnel will not use TLS. This configuration is not suitable for production environments.");
sslContext = null;
}
this.bootstrap = new ServerBootstrap()
.group(eventLoopGroup)
.channel(NioServerSocketChannel.class)
.localAddress(websocketPort)
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel socketChannel) {
socketChannel.pipeline()
.addLast(new ProxyProtocolDetectionHandler())
.addLast(new HAProxyMessageHandler());
if (sslContext != null) {
socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor));
}
socketChannel.pipeline()
.addLast(new HttpServerCodec())
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
// The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade
// request and passed it down the pipeline
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
.addLast(new WebSocketServerProtocolHandler("/", true))
// Metrics on inbound/outbound Close frames
.addLast(new WebSocketCloseMetricHandler())
// Turn generic OutboundCloseErrorMessages into websocket close frames
.addLast(new WebSocketOutboundErrorHandler())
.addLast(new RejectUnsupportedMessagesHandler())
.addLast(new WebsocketPayloadCodec())
// The WebSocket handshake complete listener will forward the first payload supplemented with
// data from the websocket handshake completion event, and then remove itself from the pipeline
.addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret))
// The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a
// NoiseHandler
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(
grpcClientConnectionManager,
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
FramingType.WEBSOCKET))
.addLast(new ErrorHandler());
}
});
}
@VisibleForTesting
InetSocketAddress getLocalAddress() {
return channel.localAddress();
}
@Override
public void start() throws InterruptedException {
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
}
@Override
public void stop() throws InterruptedException {
if (channel != null) {
channel.close().await();
}
}
}

View File

@@ -1,35 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
/**
* A "reject unsupported message" handler closes the channel if it receives messages it does not know how to process.
*/
public class RejectUnsupportedMessagesHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof final WebSocketFrame webSocketFrame) {
if (webSocketFrame instanceof final TextWebSocketFrame textWebSocketFrame) {
try {
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
} finally {
textWebSocketFrame.release();
}
} else {
// Allow all other types of WebSocket frames
context.fireChannelRead(webSocketFrame);
}
} else {
// Discard anything that's not a WebSocket frame
ReferenceCountUtil.release(message);
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
}
}
}

View File

@@ -1,50 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.micrometer.core.instrument.Metrics;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
public class WebSocketCloseMetricHandler extends ChannelDuplexHandler {
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "clientClose");
private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "serverClose");
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (msg instanceof CloseWebSocketFrame closeFrame) {
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "closeCode", validatedCloseCode(closeFrame.statusCode())).increment();
}
ctx.fireChannelRead(msg);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof CloseWebSocketFrame closeFrame) {
Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "closeCode", Integer.toString(closeFrame.statusCode())).increment();
}
ctx.write(msg, promise);
}
private static String validatedCloseCode(int closeCode) {
if (closeCode >= 1000 && closeCode <= 1015) {
// RFC-6455 pre-defined status codes
return Integer.toString(closeCode);
} else if (closeCode >= 4000 && closeCode <= 4100) {
// Application status codes
return Integer.toString(closeCode);
} else {
return "unknown";
}
}
}

View File

@@ -1,81 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.ReferenceCountUtil;
/**
* A WebSocket opening handshake handler serves as the "front door" for the WebSocket/Noise tunnel and gracefully
* rejects requests for anything other than a WebSocket connection to a known endpoint.
*/
class WebSocketOpeningHandshakeHandler extends ChannelInboundHandlerAdapter {
private final String authenticatedPath;
private final String anonymousPath;
private final String healthCheckPath;
WebSocketOpeningHandshakeHandler(final String authenticatedPath,
final String anonymousPath,
final String healthCheckPath) {
this.authenticatedPath = authenticatedPath;
this.anonymousPath = anonymousPath;
this.healthCheckPath = healthCheckPath;
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
if (message instanceof FullHttpRequest request) {
boolean shouldReleaseRequest = true;
try {
if (request.decoderResult().isSuccess()) {
if (HttpMethod.GET.equals(request.method())) {
if (authenticatedPath.equals(request.uri()) || anonymousPath.equals(request.uri())) {
if (request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) {
// Pass the request along to the websocket handshake handler and remove ourselves from the pipeline
shouldReleaseRequest = false;
context.fireChannelRead(request);
context.pipeline().remove(this);
} else {
closeConnectionWithStatus(context, request, HttpResponseStatus.UPGRADE_REQUIRED);
}
} else if (healthCheckPath.equals(request.uri())) {
closeConnectionWithStatus(context, request, HttpResponseStatus.NO_CONTENT);
} else {
closeConnectionWithStatus(context, request, HttpResponseStatus.NOT_FOUND);
}
} else {
closeConnectionWithStatus(context, request, HttpResponseStatus.METHOD_NOT_ALLOWED);
}
} else {
closeConnectionWithStatus(context, request, HttpResponseStatus.BAD_REQUEST);
}
} finally {
if (shouldReleaseRequest) {
request.release();
}
}
} else {
// Anything except HTTP requests should have been filtered out of the pipeline by now; treat this as an error
ReferenceCountUtil.release(message);
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
}
}
private static void closeConnectionWithStatus(final ChannelHandlerContext context,
final FullHttpRequest request,
final HttpResponseStatus status) {
context.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), status))
.addListener(ChannelFutureListener.CLOSE);
}
}

View File

@@ -1,60 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
/**
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
*/
class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketOutboundErrorHandler.class, "serverClose");
private boolean websocketHandshakeComplete = false;
private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
setWebsocketHandshakeComplete();
}
context.fireUserEventTriggered(event);
}
protected void setWebsocketHandshakeComplete() {
this.websocketHandshakeComplete = true;
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof OutboundCloseErrorMessage err) {
if (websocketHandshakeComplete) {
final int status = switch (err.code()) {
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
};
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
} else {
log.debug("Error {} occurred before websocket handshake complete", err);
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
// way; just close the connection instead.
ctx.close();
}
} else {
ctx.write(msg, promise);
}
}
}

View File

@@ -1,152 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.InetAddresses;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.ReferenceCountUtil;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
/**
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
* Noise handshake handler for the requested path.
*/
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
private final byte[] recognizedProxySecret;
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
@VisibleForTesting
static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy";
@VisibleForTesting
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
private InetAddress remoteAddress = null;
private HandshakePattern handshakePattern = null;
WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) {
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
// MessageDigest.equals() later.
this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8);
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final Optional<InetAddress> maybePreferredRemoteAddress =
getPreferredRemoteAddress(context, handshakeCompleteEvent);
if (maybePreferredRemoteAddress.isEmpty()) {
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
"Could not determine remote address"))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
return;
}
remoteAddress = maybePreferredRemoteAddress.get();
handshakePattern = switch (handshakeCompleteEvent.requestUri()) {
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK;
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK;
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
// internal error if something slipped through.
default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
};
}
context.fireUserEventTriggered(event);
}
@Override
public void channelRead(final ChannelHandlerContext context, final Object msg) {
try {
if (!(msg instanceof ByteBuf frame)) {
throw new IllegalStateException("Unexpected msg type: " + msg.getClass());
}
if (handshakePattern == null || remoteAddress == null) {
throw new IllegalStateException("Received payload before websocket handshake complete");
}
final NoiseHandshakeInit handshakeMessage =
new NoiseHandshakeInit(remoteAddress, handshakePattern, frame);
context.pipeline().remove(WebsocketHandshakeCompleteHandler.class);
context.fireChannelRead(handshakeMessage);
} catch (Exception e) {
ReferenceCountUtil.release(msg);
throw e;
}
}
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final byte[] recognizedProxySecretFromHeader =
handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
.getBytes(StandardCharsets.UTF_8);
final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) {
final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER);
return getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
try {
return InetAddresses.forString(mostRecentProxy);
} catch (final IllegalArgumentException e) {
log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e);
return null;
}
});
} else {
// Either we don't trust the forwarded-for header or it's not present
if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
return Optional.of(inetSocketAddress.getAddress());
} else {
log.warn("Channel's remote address was not an InetSocketAddress");
return Optional.empty();
}
}
}
/**
* Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header.
*
* @param forwardedFor the value of an X-Forwarded-For header
* @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or
* {@code forwardedFor} was null
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
* MDN</a>
*/
@VisibleForTesting
static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
return Optional.ofNullable(forwardedFor)
.map(ff -> {
final int idx = forwardedFor.lastIndexOf(',') + 1;
return idx < forwardedFor.length()
? forwardedFor.substring(idx).trim()
: null;
})
.filter(StringUtils::isNotBlank);
}
}

View File

@@ -1,38 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc.net.websocket;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
/**
* Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
* packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
*/
public class WebsocketPayloadCodec extends ChannelDuplexHandler {
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
if (msg instanceof BinaryWebSocketFrame frame) {
ctx.fireChannelRead(frame.content());
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
if (msg instanceof ByteBuf bb) {
ctx.write(new BinaryWebSocketFrame(bb), promise);
} else {
ctx.write(msg, promise);
}
}
}

View File

@@ -15,11 +15,7 @@ import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.time.Clock;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
@@ -43,7 +39,6 @@ import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
@@ -267,9 +262,8 @@ record CommandDependencies(
() -> dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion());
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
storageServiceExecutor, retryExecutor, configuration.getSecureStorageServiceConfiguration());
GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager();
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient,
grpcClientConnectionManager, disconnectionRequestListenerExecutor, retryExecutor);
disconnectionRequestListenerExecutor, retryExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster,
messageDeliveryScheduler, messageDeletionExecutor, retryExecutor, Clock.systemUTC(), experimentEnrollmentManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, retryExecutor, asyncCdnS3Client,