mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-18 14:15:17 +01:00
Add direct grpc server
This commit is contained in:
5
pom.xml
5
pom.xml
@@ -299,11 +299,6 @@
|
||||
<artifactId>simple-grpc-runtime</artifactId>
|
||||
<version>${simple-grpc.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.signal.forks</groupId>
|
||||
<artifactId>noise-java</artifactId>
|
||||
<version>0.1.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.logging.log4j</groupId>
|
||||
<artifactId>log4j-bom</artifactId>
|
||||
|
||||
@@ -102,7 +102,3 @@ turn.cloudflare.apiToken: ABCDEFGHIJKLM
|
||||
linkDevice.secret: AAAAAAAAAAA=
|
||||
|
||||
tlsKeyStore.password: unset
|
||||
|
||||
noiseTunnel.tlsKeyStorePassword: ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
||||
noiseTunnel.noiseStaticPrivateKey: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
|
||||
noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA
|
||||
|
||||
@@ -529,15 +529,6 @@ turn:
|
||||
linkDevice:
|
||||
secret: secret://linkDevice.secret
|
||||
|
||||
noiseTunnel:
|
||||
webSocketPort: 8444
|
||||
directPort: 8445
|
||||
tlsKeyStoreFile: /path/to/file.p12
|
||||
tlsKeyStoreEntryAlias: example.com
|
||||
tlsKeyStorePassword: secret://noiseTunnel.tlsKeyStorePassword
|
||||
noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey
|
||||
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret
|
||||
|
||||
externalRequestFilter:
|
||||
grpcMethods:
|
||||
- com.example.grpc.ExampleService/exampleMethod
|
||||
@@ -548,3 +539,6 @@ externalRequestFilter:
|
||||
|
||||
idlePrimaryDeviceReminder:
|
||||
minIdleDuration: P30D
|
||||
|
||||
grpc:
|
||||
port: 50051
|
||||
|
||||
@@ -82,11 +82,6 @@
|
||||
<artifactId>libsignal-server</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.signal.forks</groupId>
|
||||
<artifactId>noise-java</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.signal</groupId>
|
||||
<artifactId>simple-grpc-runtime</artifactId>
|
||||
|
||||
@@ -39,13 +39,13 @@ import org.whispersystems.textsecuregcm.configuration.FcmConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.GenericZkConfig;
|
||||
import org.whispersystems.textsecuregcm.configuration.GooglePlayBillingConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.GrpcConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.IdlePrimaryDeviceReminderConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.KeyTransparencyServiceConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.LinkDeviceSecretConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.NoiseTunnelConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.OpenTelemetryConfiguration;
|
||||
import org.whispersystems.textsecuregcm.configuration.PagedSingleUseKEMPreKeyStoreConfiguration;
|
||||
@@ -315,11 +315,6 @@ public class WhisperServerConfiguration extends Configuration {
|
||||
@JsonProperty
|
||||
private VirtualThreadConfiguration virtualThread = new VirtualThreadConfiguration();
|
||||
|
||||
@Valid
|
||||
@NotNull
|
||||
@JsonProperty
|
||||
private NoiseTunnelConfiguration noiseTunnel;
|
||||
|
||||
@Valid
|
||||
@NotNull
|
||||
@JsonProperty
|
||||
@@ -348,6 +343,11 @@ public class WhisperServerConfiguration extends Configuration {
|
||||
@NotNull
|
||||
private RetryConfiguration generalRedisRetry = new RetryConfiguration();
|
||||
|
||||
@NotNull
|
||||
@Valid
|
||||
@JsonProperty
|
||||
private GrpcConfiguration grpc;
|
||||
|
||||
public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() {
|
||||
return tlsKeyStore;
|
||||
}
|
||||
@@ -551,10 +551,6 @@ public class WhisperServerConfiguration extends Configuration {
|
||||
return virtualThread;
|
||||
}
|
||||
|
||||
public NoiseTunnelConfiguration getNoiseTunnelConfiguration() {
|
||||
return noiseTunnel;
|
||||
}
|
||||
|
||||
public ExternalRequestFilterConfiguration getExternalRequestFilterConfiguration() {
|
||||
return externalRequestFilter;
|
||||
}
|
||||
@@ -582,4 +578,8 @@ public class WhisperServerConfiguration extends Configuration {
|
||||
public RetryConfiguration getGeneralRedisRetryConfiguration() {
|
||||
return generalRedisRetry;
|
||||
}
|
||||
|
||||
public GrpcConfiguration getGrpc() {
|
||||
return grpc;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,12 +23,14 @@ import io.dropwizard.core.setup.Environment;
|
||||
import io.dropwizard.jetty.HttpsConnectorFactory;
|
||||
import io.dropwizard.lifecycle.setup.LifecycleEnvironment;
|
||||
import io.grpc.ServerBuilder;
|
||||
import io.grpc.ServerInterceptors;
|
||||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder;
|
||||
import io.lettuce.core.metrics.MicrometerOptions;
|
||||
import io.lettuce.core.resource.ClientResources;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.socket.nio.NioDatagramChannel;
|
||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||
import io.netty.resolver.ResolvedAddressTypes;
|
||||
@@ -38,16 +40,12 @@ import jakarta.servlet.DispatcherType;
|
||||
import jakarta.servlet.Filter;
|
||||
import jakarta.servlet.ServletRegistration;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.http.HttpClient;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.KeyStore;
|
||||
import java.security.PrivateKey;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
@@ -62,7 +60,6 @@ import java.util.concurrent.SynchronousQueue;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import org.eclipse.jetty.websocket.core.WebSocketComponents;
|
||||
import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents;
|
||||
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||
@@ -154,12 +151,8 @@ import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.ValidatingInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedGrpcServer;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ManagedNioEventLoopGroup;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectTunnelServer;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.websocket.NoiseWebSocketTunnelServer;
|
||||
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
|
||||
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
|
||||
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
|
||||
@@ -259,9 +252,9 @@ import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager;
|
||||
import org.whispersystems.textsecuregcm.subscriptions.GooglePlayBillingManager;
|
||||
import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
|
||||
import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
|
||||
import org.whispersystems.textsecuregcm.util.ResilienceUtil;
|
||||
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
|
||||
import org.whispersystems.textsecuregcm.util.ManagedExecutors;
|
||||
import org.whispersystems.textsecuregcm.util.ResilienceUtil;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
||||
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
|
||||
@@ -636,9 +629,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||
() -> dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion());
|
||||
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
|
||||
storageServiceExecutor, retryExecutor, config.getSecureStorageServiceConfiguration());
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager();
|
||||
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient,
|
||||
grpcClientConnectionManager, disconnectionRequestListenerExecutor, retryExecutor);
|
||||
disconnectionRequestListenerExecutor, retryExecutor);
|
||||
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, retryExecutor, asyncCdnS3Client,
|
||||
config.getCdnConfiguration().bucket());
|
||||
MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler,
|
||||
@@ -828,127 +820,63 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||
config.getAppleDeviceCheck().teamId(),
|
||||
config.getAppleDeviceCheck().bundleId());
|
||||
|
||||
final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
|
||||
|
||||
final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
|
||||
final MetricServerInterceptor metricServerInterceptor = new MetricServerInterceptor(Metrics.globalRegistry, clientReleaseManager);
|
||||
|
||||
final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor();
|
||||
final RequestAttributesInterceptor requestAttributesInterceptor =
|
||||
new RequestAttributesInterceptor(grpcClientConnectionManager);
|
||||
final RequestAttributesInterceptor requestAttributesInterceptor = new RequestAttributesInterceptor();
|
||||
|
||||
final ValidatingInterceptor validatingInterceptor = new ValidatingInterceptor();
|
||||
|
||||
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous");
|
||||
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated");
|
||||
final ExternalRequestFilter grpcExternalRequestFilter = new ExternalRequestFilter(
|
||||
config.getExternalRequestFilterConfiguration().permittedInternalRanges(),
|
||||
config.getExternalRequestFilterConfiguration().grpcMethods());
|
||||
final RequireAuthenticationInterceptor requireAuthenticationInterceptor = new RequireAuthenticationInterceptor(accountAuthenticator);
|
||||
final ProhibitAuthenticationInterceptor prohibitAuthenticationInterceptor = new ProhibitAuthenticationInterceptor();
|
||||
|
||||
final ManagedLocalGrpcServer anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, localEventLoopGroup) {
|
||||
@Override
|
||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||
final List<ServerServiceDefinition> authenticatedServices = Stream.of(
|
||||
new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager),
|
||||
ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters),
|
||||
new KeysGrpcService(accountsManager, keysManager, rateLimiters),
|
||||
new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
|
||||
config.getBadges(), profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations))
|
||||
.map(bindableService -> ServerInterceptors.intercept(bindableService,
|
||||
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
||||
// depends on the user-agent context so it has to come first here!
|
||||
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
|
||||
serverBuilder
|
||||
.intercept(
|
||||
new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(),
|
||||
config.getExternalRequestFilterConfiguration().grpcMethods()))
|
||||
.intercept(validatingInterceptor)
|
||||
.intercept(metricServerInterceptor)
|
||||
.intercept(errorMappingInterceptor)
|
||||
.intercept(remoteDeprecationFilter)
|
||||
.intercept(requestAttributesInterceptor)
|
||||
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager))
|
||||
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
|
||||
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager, zkSecretParams, Clock.systemUTC()))
|
||||
.addService(new PaymentsGrpcService(currencyManager))
|
||||
.addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
|
||||
.addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkSecretParams));
|
||||
}
|
||||
};
|
||||
validatingInterceptor,
|
||||
metricServerInterceptor,
|
||||
errorMappingInterceptor,
|
||||
remoteDeprecationFilter,
|
||||
requestAttributesInterceptor,
|
||||
requireAuthenticationInterceptor))
|
||||
.toList();
|
||||
|
||||
final ManagedLocalGrpcServer authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, localEventLoopGroup) {
|
||||
@Override
|
||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||
final List<ServerServiceDefinition> unauthenticatedServices = Stream.of(
|
||||
new AccountsAnonymousGrpcService(accountsManager, rateLimiters),
|
||||
new KeysAnonymousGrpcService(accountsManager, keysManager, zkSecretParams, Clock.systemUTC()),
|
||||
new PaymentsGrpcService(currencyManager),
|
||||
ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config),
|
||||
new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkSecretParams))
|
||||
.map(bindableService -> ServerInterceptors.intercept(bindableService,
|
||||
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
||||
// depends on the user-agent context so it has to come first here!
|
||||
// http://grpc.github.io/grpc-java/javadoc/io/grpc/ServerBuilder.html#intercept-io.grpc.ServerInterceptor-
|
||||
serverBuilder
|
||||
.intercept(validatingInterceptor)
|
||||
.intercept(metricServerInterceptor)
|
||||
.intercept(errorMappingInterceptor)
|
||||
.intercept(remoteDeprecationFilter)
|
||||
.intercept(requestAttributesInterceptor)
|
||||
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager))
|
||||
.addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager))
|
||||
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
|
||||
.addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters))
|
||||
.addService(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
|
||||
config.getBadges(), profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations));
|
||||
}
|
||||
};
|
||||
grpcExternalRequestFilter,
|
||||
validatingInterceptor,
|
||||
metricServerInterceptor,
|
||||
errorMappingInterceptor,
|
||||
remoteDeprecationFilter,
|
||||
requestAttributesInterceptor,
|
||||
prohibitAuthenticationInterceptor))
|
||||
.toList();
|
||||
|
||||
@Nullable final X509Certificate[] noiseWebSocketTlsCertificateChain;
|
||||
@Nullable final PrivateKey noiseWebSocketTlsPrivateKey;
|
||||
final ServerBuilder<?> serverBuilder =
|
||||
NettyServerBuilder.forAddress(new InetSocketAddress(config.getGrpc().bindAddress(), config.getGrpc().port()));
|
||||
authenticatedServices.forEach(serverBuilder::addService);
|
||||
unauthenticatedServices.forEach(serverBuilder::addService);
|
||||
final ManagedGrpcServer exposedGrpcServer = new ManagedGrpcServer(serverBuilder.build());
|
||||
|
||||
if (config.getNoiseTunnelConfiguration().tlsKeyStoreFile() != null &&
|
||||
config.getNoiseTunnelConfiguration().tlsKeyStoreEntryAlias() != null &&
|
||||
config.getNoiseTunnelConfiguration().tlsKeyStorePassword() != null) {
|
||||
|
||||
try (final FileInputStream websocketNoiseTunnelTlsKeyStoreInputStream = new FileInputStream(config.getNoiseTunnelConfiguration().tlsKeyStoreFile())) {
|
||||
final KeyStore keyStore = KeyStore.getInstance("PKCS12");
|
||||
keyStore.load(websocketNoiseTunnelTlsKeyStoreInputStream, config.getNoiseTunnelConfiguration().tlsKeyStorePassword().value().toCharArray());
|
||||
|
||||
final KeyStore.PrivateKeyEntry privateKeyEntry = (KeyStore.PrivateKeyEntry) keyStore.getEntry(
|
||||
config.getNoiseTunnelConfiguration().tlsKeyStoreEntryAlias(),
|
||||
new KeyStore.PasswordProtection(config.getNoiseTunnelConfiguration().tlsKeyStorePassword().value().toCharArray()));
|
||||
|
||||
noiseWebSocketTlsCertificateChain =
|
||||
Arrays.copyOf(privateKeyEntry.getCertificateChain(), privateKeyEntry.getCertificateChain().length, X509Certificate[].class);
|
||||
|
||||
noiseWebSocketTlsPrivateKey = privateKeyEntry.getPrivateKey();
|
||||
}
|
||||
} else {
|
||||
noiseWebSocketTlsCertificateChain = null;
|
||||
noiseWebSocketTlsPrivateKey = null;
|
||||
}
|
||||
|
||||
final ExecutorService noiseWebSocketDelegatedTaskExecutor = ExecutorServiceBuilder.of(environment, "noiseWebsocketDelegatedTask")
|
||||
.minThreads(8)
|
||||
.maxThreads(8)
|
||||
.allowCoreThreadTimeOut(false)
|
||||
.build();
|
||||
|
||||
final ManagedNioEventLoopGroup noiseTunnelEventLoopGroup = new ManagedNioEventLoopGroup();
|
||||
|
||||
final NoiseWebSocketTunnelServer noiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(
|
||||
config.getNoiseTunnelConfiguration().webSocketPort(),
|
||||
noiseWebSocketTlsCertificateChain,
|
||||
noiseWebSocketTlsPrivateKey,
|
||||
noiseTunnelEventLoopGroup,
|
||||
noiseWebSocketDelegatedTaskExecutor,
|
||||
grpcClientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
config.getNoiseTunnelConfiguration().noiseStaticKeyPair(),
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress,
|
||||
config.getNoiseTunnelConfiguration().recognizedProxySecret().value());
|
||||
|
||||
final NoiseDirectTunnelServer noiseDirectTunnelServer = new NoiseDirectTunnelServer(
|
||||
config.getNoiseTunnelConfiguration().directPort(),
|
||||
noiseTunnelEventLoopGroup,
|
||||
grpcClientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
config.getNoiseTunnelConfiguration().noiseStaticKeyPair(),
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress);
|
||||
|
||||
environment.lifecycle().manage(localEventLoopGroup);
|
||||
environment.lifecycle().manage(dnsResolutionEventLoopGroup);
|
||||
environment.lifecycle().manage(anonymousGrpcServer);
|
||||
environment.lifecycle().manage(authenticatedGrpcServer);
|
||||
environment.lifecycle().manage(noiseTunnelEventLoopGroup);
|
||||
environment.lifecycle().manage(noiseWebSocketTunnelServer);
|
||||
environment.lifecycle().manage(noiseDirectTunnelServer);
|
||||
environment.lifecycle().manage(exposedGrpcServer);
|
||||
|
||||
final List<Filter> filters = new ArrayList<>();
|
||||
filters.add(remoteDeprecationFilter);
|
||||
|
||||
@@ -27,8 +27,6 @@ import java.util.concurrent.ScheduledExecutorService;
|
||||
import javax.annotation.Nullable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
|
||||
@@ -47,7 +45,6 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte[]> implements Managed {
|
||||
|
||||
private final FaultTolerantRedisClient pubSubClient;
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
private final Executor listenerEventExecutor;
|
||||
private final ScheduledExecutorService retryExecutor;
|
||||
|
||||
@@ -74,12 +71,10 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte
|
||||
private record AccountIdentifierAndDeviceId(UUID accountIdentifier, byte deviceId) {}
|
||||
|
||||
public DisconnectionRequestManager(final FaultTolerantRedisClient pubSubClient,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final Executor listenerEventExecutor,
|
||||
final ScheduledExecutorService retryExecutor) {
|
||||
|
||||
this.pubSubClient = pubSubClient;
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
this.listenerEventExecutor = listenerEventExecutor;
|
||||
this.retryExecutor = retryExecutor;
|
||||
}
|
||||
@@ -223,8 +218,6 @@ public class DisconnectionRequestManager extends RedisPubSubAdapter<byte[], byte
|
||||
}
|
||||
|
||||
deviceIds.forEach(deviceId -> {
|
||||
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(accountIdentifier, deviceId));
|
||||
|
||||
listeners.getOrDefault(new AccountIdentifierAndDeviceId(accountIdentifier, deviceId), Collections.emptyList())
|
||||
.forEach(listener -> listenerEventExecutor.execute(() -> {
|
||||
try {
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
AbstractAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
||||
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call)
|
||||
throws ChannelNotFoundException {
|
||||
|
||||
return grpcClientConnectionManager.getAuthenticatedDevice(call);
|
||||
}
|
||||
}
|
||||
@@ -1,40 +1,30 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
|
||||
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
|
||||
* device are closed with an {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined
|
||||
* (i.e. because the underlying remote channel closed before the {@code ServerCall} started), the interceptor will
|
||||
* reject the call with a status of {@code UNAVAILABLE}.
|
||||
* contain an authorization header in the request metdata. Calls with an associated authenticated device are closed with
|
||||
* an {@code UNAUTHENTICATED} status.
|
||||
*/
|
||||
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
|
||||
public ProhibitAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
super(grpcClientConnectionManager);
|
||||
}
|
||||
public class ProhibitAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
try {
|
||||
return getAuthenticatedDevice(call)
|
||||
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-prohibited
|
||||
// service via an authenticated connection, then that's actually a server configuration issue and not a
|
||||
// problem with the client's request.
|
||||
.map(ignored -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL))
|
||||
.orElseGet(() -> next.startCall(call, headers));
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
final Metadata headers, final ServerCallHandler<ReqT, RespT> next) {
|
||||
final String authHeaderString = headers.get(Metadata.Key.of(RequireAuthenticationInterceptor.AUTHORIZATION_HEADER, Metadata.ASCII_STRING_MARSHALLER));
|
||||
if (authHeaderString != null) {
|
||||
call.close(Status.UNAUTHENTICATED.withDescription("authorization header forbidden"), new Metadata());
|
||||
return new ServerCall.Listener<>() {};
|
||||
}
|
||||
return next.startCall(call, headers);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,43 +1,68 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.dropwizard.auth.basic.BasicCredentials;
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||
import org.whispersystems.textsecuregcm.grpc.ServerInterceptorUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
|
||||
/**
|
||||
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
|
||||
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
|
||||
* status. If a call's authentication status cannot be determined (i.e. because the underlying remote channel closed
|
||||
* before the {@code ServerCall} started), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
|
||||
* A "require authentication" interceptor authenticates requests and attaches the {@link AuthenticatedDevice} to the
|
||||
* current gRPC context. Calls without authentication or with invalid credentials are closed with an
|
||||
* {@code UNAUTHENTICATED} status. If a call's authentication status cannot be determined (i.e. because the accounts
|
||||
* database is unavailable), the interceptor will reject the call with a status of {@code UNAVAILABLE}.
|
||||
*/
|
||||
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
|
||||
public class RequireAuthenticationInterceptor implements ServerInterceptor {
|
||||
|
||||
public RequireAuthenticationInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
super(grpcClientConnectionManager);
|
||||
static final String AUTHORIZATION_HEADER = "authorization";
|
||||
|
||||
private final AccountAuthenticator authenticator;
|
||||
|
||||
public RequireAuthenticationInterceptor(final AccountAuthenticator authenticator) {
|
||||
this.authenticator = authenticator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
final Metadata headers, final ServerCallHandler<ReqT, RespT> next) {
|
||||
final String authHeaderString = headers.get(
|
||||
Metadata.Key.of(AUTHORIZATION_HEADER, Metadata.ASCII_STRING_MARSHALLER));
|
||||
|
||||
try {
|
||||
return getAuthenticatedDevice(call)
|
||||
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
|
||||
if (authHeaderString == null) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call,
|
||||
Status.UNAUTHENTICATED.withDescription("missing authorization header"));
|
||||
}
|
||||
|
||||
final Optional<BasicCredentials> basicCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeaderString);
|
||||
if (basicCredentials.isEmpty()) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call,
|
||||
Status.UNAUTHENTICATED.withDescription("malformed authorization header"));
|
||||
}
|
||||
|
||||
final Optional<org.whispersystems.textsecuregcm.auth.AuthenticatedDevice> authenticated =
|
||||
authenticator.authenticate(basicCredentials.get());
|
||||
if (authenticated.isEmpty()) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call,
|
||||
Status.UNAUTHENTICATED.withDescription("invalid credentials"));
|
||||
}
|
||||
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(
|
||||
authenticated.get().accountIdentifier(),
|
||||
authenticated.get().deviceId());
|
||||
|
||||
return Contexts.interceptCall(Context.current()
|
||||
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
|
||||
call, headers, next))
|
||||
// Status.INTERNAL may seem a little surprising here, but if a caller is reaching an authentication-required
|
||||
// service via an unauthenticated connection, then that's actually a server configuration issue and not a
|
||||
// problem with the client's request.
|
||||
.orElseGet(() -> ServerInterceptorUtil.closeWithStatus(call, Status.INTERNAL));
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
call, headers, next);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.configuration;
|
||||
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
|
||||
public record GrpcConfiguration(@NotNull String bindAddress, @NotNull Integer port) {
|
||||
public GrpcConfiguration {
|
||||
if (bindAddress == null || bindAddress.isEmpty()) {
|
||||
bindAddress = "localhost";
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.configuration;
|
||||
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import jakarta.validation.constraints.Positive;
|
||||
import javax.annotation.Nullable;
|
||||
import org.signal.libsignal.protocol.InvalidKeyException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPrivateKey;
|
||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
|
||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
|
||||
|
||||
public record NoiseTunnelConfiguration(@Positive int webSocketPort,
|
||||
@Positive int directPort,
|
||||
@Nullable String tlsKeyStoreFile,
|
||||
@Nullable String tlsKeyStoreEntryAlias,
|
||||
@Nullable SecretString tlsKeyStorePassword,
|
||||
@NotNull SecretBytes noiseStaticPrivateKey,
|
||||
@NotNull SecretString recognizedProxySecret) {
|
||||
|
||||
public ECKeyPair noiseStaticKeyPair() throws InvalidKeyException {
|
||||
final ECPrivateKey privateKey = new ECPrivateKey(noiseStaticPrivateKey().value());
|
||||
|
||||
return new ECKeyPair(privateKey.publicKey(), privateKey);
|
||||
}
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import io.grpc.Context;
|
||||
import io.grpc.ForwardingServerCallListener;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
/**
|
||||
* Then channel shutdown interceptor rejects new requests if a channel is shutting down and works in tandem with
|
||||
* {@link GrpcClientConnectionManager} to maintain an active call count for each channel otherwise.
|
||||
*/
|
||||
public class ChannelShutdownInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
public ChannelShutdownInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
|
||||
if (!grpcClientConnectionManager.handleServerCallStart(call)) {
|
||||
// Don't allow new calls if the connection is getting ready to close
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
|
||||
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<>(next.startCall(call, headers)) {
|
||||
@Override
|
||||
public void onComplete() {
|
||||
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||
super.onComplete();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
grpcClientConnectionManager.handleServerCallComplete(call);
|
||||
super.onCancel();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1,41 +1,108 @@
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import com.google.common.net.HttpHeaders;
|
||||
import io.grpc.Context;
|
||||
import io.grpc.Contexts;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.Status;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketAddress;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* The request attributes interceptor makes request attributes from the underlying remote channel available to service
|
||||
* The request attributes interceptor makes common request attributes from call metadata available to service
|
||||
* implementations by attaching them to a {@link Context} attribute that can be read via {@link RequestAttributesUtil}.
|
||||
* All server calls should have request attributes, and calls will be rejected with a status of {@code UNAVAILABLE} if
|
||||
* request attributes are unavailable (i.e. the underlying channel closed before the {@code ServerCall} started).
|
||||
*
|
||||
* @see RequestAttributesUtil
|
||||
*/
|
||||
public class RequestAttributesInterceptor implements ServerInterceptor {
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
|
||||
|
||||
public RequestAttributesInterceptor(final GrpcClientConnectionManager grpcClientConnectionManager) {
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
}
|
||||
private static final Metadata.Key<String> ACCEPT_LANG_KEY =
|
||||
Metadata.Key.of(HttpHeaders.ACCEPT_LANGUAGE, Metadata.ASCII_STRING_MARSHALLER);
|
||||
|
||||
private static final Metadata.Key<String> USER_AGENT_KEY =
|
||||
Metadata.Key.of(HttpHeaders.USER_AGENT, Metadata.ASCII_STRING_MARSHALLER);
|
||||
|
||||
private static final Metadata.Key<String> X_FORWARDED_FOR_KEY =
|
||||
Metadata.Key.of(HttpHeaders.X_FORWARDED_FOR, Metadata.ASCII_STRING_MARSHALLER);
|
||||
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers,
|
||||
final ServerCallHandler<ReqT, RespT> next) {
|
||||
final String userAgentHeader = headers.get(USER_AGENT_KEY);
|
||||
final String acceptLanguageHeader = headers.get(ACCEPT_LANG_KEY);
|
||||
final String xForwardedForHeader = headers.get(X_FORWARDED_FOR_KEY);
|
||||
|
||||
final Optional<InetAddress> remoteAddress = getMostRecentProxy(xForwardedForHeader)
|
||||
.flatMap(mostRecentProxy -> {
|
||||
try {
|
||||
return Contexts.interceptCall(Context.current()
|
||||
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY,
|
||||
grpcClientConnectionManager.getRequestAttributes(call)), call, headers, next);
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
return Optional.of(InetAddress.ofLiteral(mostRecentProxy));
|
||||
} catch (IllegalArgumentException e) {
|
||||
log.warn("Failed to parse most recent proxy {} as an IP address", mostRecentProxy, e);
|
||||
return Optional.empty();
|
||||
}
|
||||
})
|
||||
.or(() -> {
|
||||
log.warn("No usable X-Forwarded-For header present, using remote socket address");
|
||||
final SocketAddress socketAddress = call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
|
||||
if (socketAddress == null || !(socketAddress instanceof InetSocketAddress inetAddress)) {
|
||||
log.warn("Remote socket address not present or is not an inet address: {}", socketAddress);
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.of(inetAddress.getAddress());
|
||||
});
|
||||
if (!remoteAddress.isPresent()) {
|
||||
return ServerInterceptorUtil.closeWithStatus(call, Status.UNAVAILABLE);
|
||||
}
|
||||
|
||||
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
|
||||
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
|
||||
try {
|
||||
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
|
||||
} catch (final IllegalArgumentException e) {
|
||||
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
|
||||
}
|
||||
}
|
||||
|
||||
final RequestAttributes requestAttributes =
|
||||
new RequestAttributes(remoteAddress.get(), userAgentHeader, acceptLanguages);
|
||||
return Contexts.interceptCall(
|
||||
Context.current().withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes),
|
||||
call, headers, next);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header.
|
||||
*
|
||||
* @param forwardedFor the value of an X-Forwarded-For header
|
||||
* @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or
|
||||
* {@code forwardedFor} was null
|
||||
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
|
||||
* MDN</a>
|
||||
*/
|
||||
public static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
|
||||
return Optional.ofNullable(forwardedFor)
|
||||
.map(ff -> {
|
||||
final int idx = forwardedFor.lastIndexOf(',') + 1;
|
||||
return idx < forwardedFor.length()
|
||||
? forwardedFor.substring(idx).trim()
|
||||
: null;
|
||||
})
|
||||
.filter(StringUtils::isNotBlank);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
|
||||
/**
|
||||
* An error handler serves as a general backstop for exceptions elsewhere in the pipeline. It translates exceptions
|
||||
* thrown in inbound handlers into {@link OutboundCloseErrorMessage}s.
|
||||
*/
|
||||
public class ErrorHandler extends ChannelInboundHandlerAdapter {
|
||||
private static final Logger log = LoggerFactory.getLogger(ErrorHandler.class);
|
||||
|
||||
private static OutboundCloseErrorMessage NOISE_ENCRYPTION_ERROR_CLOSE = new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.NOISE_ERROR,
|
||||
"Noise encryption error");
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||
final OutboundCloseErrorMessage closeMessage = switch (ExceptionUtils.unwrap(cause)) {
|
||||
case NoiseHandshakeException e -> new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.NOISE_HANDSHAKE_ERROR,
|
||||
e.getMessage());
|
||||
case BadPaddingException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||
case NoiseException ignored -> NOISE_ENCRYPTION_ERROR_CLOSE;
|
||||
default -> {
|
||||
log.warn("An unexpected exception reached the end of the pipeline", cause);
|
||||
yield new OutboundCloseErrorMessage(
|
||||
OutboundCloseErrorMessage.Code.INTERNAL_SERVER_ERROR,
|
||||
cause.getMessage());
|
||||
}
|
||||
};
|
||||
|
||||
context.writeAndFlush(closeMessage)
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
}
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.Tag;
|
||||
import io.micrometer.core.instrument.Tags;
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
|
||||
/**
|
||||
* An "establish local connection" handler waits for a Noise handshake to complete upstream in the pipeline, buffering
|
||||
* any inbound messages until the connection is fully-established, and then opens a proxy connection to a local gRPC
|
||||
* server.
|
||||
*/
|
||||
public class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class);
|
||||
|
||||
private final GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
private final LocalAddress authenticatedGrpcServerAddress;
|
||||
private final LocalAddress anonymousGrpcServerAddress;
|
||||
private final FramingType framingType;
|
||||
|
||||
private final List<Object> pendingReads = new ArrayList<>();
|
||||
|
||||
private static final String CONNECTION_ESTABLISHED_COUNTER_NAME = MetricsUtil.name(EstablishLocalGrpcConnectionHandler.class, "established");
|
||||
|
||||
public EstablishLocalGrpcConnectionHandler(final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final FramingType framingType) {
|
||||
|
||||
this.grpcClientConnectionManager = grpcClientConnectionManager;
|
||||
|
||||
this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress;
|
||||
this.anonymousGrpcServerAddress = anonymousGrpcServerAddress;
|
||||
this.framingType = framingType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||
pendingReads.add(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent(
|
||||
final Optional<AuthenticatedDevice> authenticatedDevice,
|
||||
InetAddress remoteAddress, String userAgent, String acceptLanguage)) {
|
||||
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
|
||||
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
|
||||
// connect to the anonymous service. If it does have an authenticated device, we assume we're aiming for the
|
||||
// authenticated service.
|
||||
final LocalAddress grpcServerAddress = authenticatedDevice.isPresent()
|
||||
? authenticatedGrpcServerAddress
|
||||
: anonymousGrpcServerAddress;
|
||||
|
||||
GrpcClientConnectionManager.handleHandshakeInitiated(
|
||||
remoteChannelContext.channel(), remoteAddress, userAgent, acceptLanguage);
|
||||
|
||||
final List<Tag> tags = UserAgentTagUtil.getLibsignalAndPlatformTags(userAgent);
|
||||
Metrics.counter(CONNECTION_ESTABLISHED_COUNTER_NAME, Tags.of(tags)
|
||||
.and("authenticated", Boolean.toString(authenticatedDevice.isPresent()))
|
||||
.and("framingType", framingType.name()))
|
||||
.increment();
|
||||
|
||||
new Bootstrap()
|
||||
.remoteAddress(grpcServerAddress)
|
||||
.channel(LocalChannel.class)
|
||||
.group(remoteChannelContext.channel().eventLoop())
|
||||
.handler(new ChannelInitializer<LocalChannel>() {
|
||||
@Override
|
||||
protected void initChannel(final LocalChannel localChannel) {
|
||||
localChannel.pipeline().addLast(new ProxyHandler(remoteChannelContext.channel()));
|
||||
}
|
||||
})
|
||||
.connect()
|
||||
.addListener((ChannelFutureListener) localChannelFuture -> {
|
||||
if (localChannelFuture.isSuccess()) {
|
||||
grpcClientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
|
||||
remoteChannelContext.channel(),
|
||||
authenticatedDevice);
|
||||
|
||||
// Close the local connection if the remote channel closes and vice versa
|
||||
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
|
||||
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
|
||||
remoteChannelContext.channel()
|
||||
.write(new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed"))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
|
||||
|
||||
remoteChannelContext.pipeline()
|
||||
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
|
||||
|
||||
// Flush any buffered reads we accumulated while waiting to open the connection
|
||||
pendingReads.forEach(remoteChannelContext::fireChannelRead);
|
||||
pendingReads.clear();
|
||||
|
||||
remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
|
||||
} else {
|
||||
log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause());
|
||||
remoteChannelContext.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
remoteChannelContext.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
pendingReads.forEach(ReferenceCountUtil::release);
|
||||
pendingReads.clear();
|
||||
}
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
public enum FramingType {
|
||||
NOISE_DIRECT,
|
||||
WEBSOCKET
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.ServerCall;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.util.AttributeKey;
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
||||
import org.whispersystems.textsecuregcm.util.ClosableEpoch;
|
||||
|
||||
/**
|
||||
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
|
||||
* Noise tunnel. It provides access to metadata associated with the remote connection, including the authenticated
|
||||
* identity of the device that opened the connection (for non-anonymous connections). It can also close connections
|
||||
* associated with a given device if that device's credentials have changed and clients must reauthenticate.
|
||||
* <p>
|
||||
* In general, all {@link ServerCall}s <em>must</em> have a local address that in turn <em>should</em> be resolvable to
|
||||
* a remote channel, which <em>must</em> have associated request attributes and authentication status. It is possible
|
||||
* that a server call's local address may not be resolvable to a remote channel if the remote channel closed in the
|
||||
* narrow window between a server call being created and the start of call execution, in which case accessor methods
|
||||
* in this class will throw a {@link ChannelNotFoundException}.
|
||||
* <p>
|
||||
* A gRPC client connection manager's methods for getting request attributes accept {@link ServerCall} entities to
|
||||
* identify connections. In general, these methods should only be called from {@link io.grpc.ServerInterceptor}s.
|
||||
* Methods for requesting connection closure accept an {@link AuthenticatedDevice} to identify the connection and may
|
||||
* be called from any application code.
|
||||
*/
|
||||
public class GrpcClientConnectionManager {
|
||||
|
||||
private final Map<LocalAddress, Channel> remoteChannelsByLocalAddress = new ConcurrentHashMap<>();
|
||||
private final Map<AuthenticatedDevice, List<Channel>> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>();
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<AuthenticatedDevice> AUTHENTICATED_DEVICE_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "authenticatedDevice");
|
||||
|
||||
@VisibleForTesting
|
||||
public static final AttributeKey<RequestAttributes> REQUEST_ATTRIBUTES_KEY =
|
||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "requestAttributes");
|
||||
|
||||
@VisibleForTesting
|
||||
static final AttributeKey<ClosableEpoch> EPOCH_ATTRIBUTE_KEY =
|
||||
AttributeKey.valueOf(GrpcClientConnectionManager.class, "epoch");
|
||||
|
||||
private static final OutboundCloseErrorMessage SERVER_CLOSED =
|
||||
new OutboundCloseErrorMessage(OutboundCloseErrorMessage.Code.SERVER_CLOSED, "server closed");
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(GrpcClientConnectionManager.class);
|
||||
|
||||
/**
|
||||
* Returns the authenticated device associated with the given server call, if any. If the connection is anonymous
|
||||
* (i.e. unauthenticated), the returned value will be empty.
|
||||
*
|
||||
* @param serverCall the gRPC server call for which to find an authenticated device
|
||||
*
|
||||
* @return the authenticated device associated with the given local address, if any
|
||||
*
|
||||
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
|
||||
* generally indicates that the channel has closed while request processing is still in progress
|
||||
*/
|
||||
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> serverCall)
|
||||
throws ChannelNotFoundException {
|
||||
|
||||
return getAuthenticatedDevice(getRemoteChannel(serverCall));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Optional<AuthenticatedDevice> getAuthenticatedDevice(final Channel remoteChannel) {
|
||||
return Optional.ofNullable(remoteChannel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the request attributes associated with the given server call.
|
||||
*
|
||||
* @param serverCall the gRPC server call for which to retrieve request attributes
|
||||
*
|
||||
* @return the request attributes associated with the given server call
|
||||
*
|
||||
* @throws ChannelNotFoundException if the server call is not associated with a known channel; in practice, this
|
||||
* generally indicates that the channel has closed while request processing is still in progress
|
||||
*/
|
||||
public RequestAttributes getRequestAttributes(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
|
||||
return getRequestAttributes(getRemoteChannel(serverCall));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
RequestAttributes getRequestAttributes(final Channel remoteChannel) {
|
||||
final RequestAttributes requestAttributes = remoteChannel.attr(REQUEST_ATTRIBUTES_KEY).get();
|
||||
|
||||
if (requestAttributes == null) {
|
||||
throw new IllegalStateException("Channel does not have request attributes");
|
||||
}
|
||||
|
||||
return requestAttributes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the start of a server call, incrementing the active call count for the remote channel associated with the
|
||||
* given server call.
|
||||
*
|
||||
* @param serverCall the server call to start
|
||||
*
|
||||
* @return {@code true} if the call should start normally or {@code false} if the call should be aborted because the
|
||||
* underlying channel is closing
|
||||
*/
|
||||
public boolean handleServerCallStart(final ServerCall<?, ?> serverCall) {
|
||||
try {
|
||||
return getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().tryArrive();
|
||||
} catch (final ChannelNotFoundException e) {
|
||||
// This would only happen if the channel had already closed, which is certainly possible. In this case, the call
|
||||
// should certainly not proceed.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles completion (successful or not) of a server call, decrementing the active call count for the remote channel
|
||||
* associated with the given server call.
|
||||
*
|
||||
* @param serverCall the server call to complete
|
||||
*/
|
||||
public void handleServerCallComplete(final ServerCall<?, ?> serverCall) {
|
||||
try {
|
||||
getRemoteChannel(serverCall).attr(EPOCH_ATTRIBUTE_KEY).get().depart();
|
||||
} catch (final ChannelNotFoundException ignored) {
|
||||
// In practice, we'd only get here if the channel has already closed, so we can just ignore the exception
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes any client connections to this host associated with the given authenticated device.
|
||||
*
|
||||
* @param authenticatedDevice the authenticated device for which to close connections
|
||||
*/
|
||||
public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
|
||||
// Channels will actually get removed from the list/map by their closeFuture listeners. We copy the list to avoid
|
||||
// concurrent modification; it's possible (though practically unlikely) that a channel can close and remove itself
|
||||
// from the list while we're still iterating, resulting in a `ConcurrentModificationException`.
|
||||
final List<Channel> channelsToClose =
|
||||
new ArrayList<>(remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()));
|
||||
|
||||
channelsToClose.forEach(channel -> channel.attr(EPOCH_ATTRIBUTE_KEY).get().close());
|
||||
}
|
||||
|
||||
private static void closeRemoteChannel(final Channel channel) {
|
||||
channel.writeAndFlush(SERVER_CLOSED).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
@Nullable List<Channel> getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) {
|
||||
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
|
||||
}
|
||||
|
||||
private Channel getRemoteChannel(final ServerCall<?, ?> serverCall) throws ChannelNotFoundException {
|
||||
return getRemoteChannel(getLocalAddress(serverCall));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
Channel getRemoteChannel(final LocalAddress localAddress) throws ChannelNotFoundException {
|
||||
final Channel remoteChannel = remoteChannelsByLocalAddress.get(localAddress);
|
||||
|
||||
if (remoteChannel == null) {
|
||||
throw new ChannelNotFoundException();
|
||||
}
|
||||
|
||||
return remoteChannelsByLocalAddress.get(localAddress);
|
||||
}
|
||||
|
||||
private static LocalAddress getLocalAddress(final ServerCall<?, ?> serverCall) {
|
||||
// In this server, gRPC's "remote" channel is actually a local channel that proxies to a distinct Noise channel.
|
||||
// The gRPC "remote" address is the "local address" for the proxy connection, and the local address uniquely maps to
|
||||
// a proxied Noise channel.
|
||||
if (!(serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress)) {
|
||||
throw new IllegalArgumentException("Unexpected channel type: " + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
|
||||
}
|
||||
|
||||
return localAddress;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles receipt of a handshake message and associates attributes and headers from the handshake
|
||||
* request with the channel via which the handshake took place.
|
||||
*
|
||||
* @param channel the channel where the handshake was initiated
|
||||
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
|
||||
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
|
||||
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
|
||||
* {@code null}
|
||||
*/
|
||||
public static void handleHandshakeInitiated(final Channel channel,
|
||||
final InetAddress preferredRemoteAddress,
|
||||
@Nullable final String userAgentHeader,
|
||||
@Nullable final String acceptLanguageHeader) {
|
||||
|
||||
@Nullable List<Locale.LanguageRange> acceptLanguages = Collections.emptyList();
|
||||
|
||||
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
|
||||
try {
|
||||
acceptLanguages = Locale.LanguageRange.parse(acceptLanguageHeader);
|
||||
} catch (final IllegalArgumentException e) {
|
||||
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
|
||||
}
|
||||
}
|
||||
|
||||
channel.attr(REQUEST_ATTRIBUTES_KEY)
|
||||
.set(new RequestAttributes(preferredRemoteAddress, userAgentHeader, acceptLanguages));
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles successful establishment of a Noise connection from a remote client to a local gRPC server.
|
||||
*
|
||||
* @param localChannel the newly-opened local channel between the Noise tunnel and the local gRPC server
|
||||
* @param remoteChannel the channel from the remote client to the Noise tunnel
|
||||
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
|
||||
*/
|
||||
void handleConnectionEstablished(final LocalChannel localChannel,
|
||||
final Channel remoteChannel,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||
|
||||
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
|
||||
remoteChannel.attr(GrpcClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
|
||||
|
||||
remoteChannel.attr(EPOCH_ATTRIBUTE_KEY)
|
||||
.set(new ClosableEpoch(() -> closeRemoteChannel(remoteChannel)));
|
||||
|
||||
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
|
||||
|
||||
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
||||
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
|
||||
final List<Channel> channels = existingChannelList != null ? existingChannelList : new ArrayList<>();
|
||||
channels.add(remoteChannel);
|
||||
|
||||
return channels;
|
||||
}));
|
||||
|
||||
remoteChannel.closeFuture().addListener(closeFuture -> {
|
||||
remoteChannelsByLocalAddress.remove(localChannel.localAddress());
|
||||
|
||||
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
|
||||
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
|
||||
if (existingChannelList == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
existingChannelList.remove(remoteChannel);
|
||||
|
||||
return existingChannelList.isEmpty() ? null : existingChannelList;
|
||||
}));
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* An HAProxy message handler handles decoded HAProxyMessage instances, removing itself from the pipeline once it has
|
||||
* either handled a proxy protocol message or determined that no such message is coming.
|
||||
*/
|
||||
public class HAProxyMessageHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(HAProxyMessageHandler.class);
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof HAProxyMessage haProxyMessage) {
|
||||
// Some network/deployment configurations will send us a proxy protocol message, but we don't use it. We still
|
||||
// need to clear it from the pipeline to avoid confusing the TLS machinery, though.
|
||||
log.debug("Discarding HAProxy message: {}", haProxyMessage);
|
||||
haProxyMessage.release();
|
||||
} else {
|
||||
super.channelRead(context, message);
|
||||
}
|
||||
|
||||
// Regardless of the type of the first message, we'll only ever receive zero or one HAProxyMessages. After the first
|
||||
// message, all others will just be "normal" messages, and our work here is done.
|
||||
context.pipeline().remove(this);
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
public enum HandshakePattern {
|
||||
NK("Noise_NK_25519_ChaChaPoly_BLAKE2b"),
|
||||
IK("Noise_IK_25519_ChaChaPoly_BLAKE2b");
|
||||
|
||||
private final String protocol;
|
||||
|
||||
public String protocol() {
|
||||
return protocol;
|
||||
}
|
||||
|
||||
|
||||
HandshakePattern(String protocol) {
|
||||
this.protocol = protocol;
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
|
||||
/**
|
||||
* A wrapper for a Netty {@link DefaultEventLoopGroup} that implements Dropwizard's {@link Managed} interface, allowing
|
||||
* Dropwizard to manage the lifecycle of the event loop group.
|
||||
*/
|
||||
public class ManagedDefaultEventLoopGroup extends DefaultEventLoopGroup implements Managed {
|
||||
|
||||
@Override
|
||||
public void stop() throws InterruptedException {
|
||||
this.shutdownGracefully().await();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.grpc.Server;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
public class ManagedGrpcServer implements Managed {
|
||||
private final Server server;
|
||||
|
||||
public ManagedGrpcServer(Server server) {
|
||||
this.server = server;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws IOException {
|
||||
server.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() {
|
||||
try {
|
||||
server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
|
||||
} catch (final InterruptedException e) {
|
||||
server.shutdownNow();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.ServerBuilder;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* A managed, local gRPC server configures and wraps a gRPC {@link Server} that listens on a Netty {@link LocalAddress}
|
||||
* and whose lifecycle is managed by Dropwizard via the {@link Managed} interface.
|
||||
*/
|
||||
public abstract class ManagedLocalGrpcServer implements Managed {
|
||||
|
||||
private final Server server;
|
||||
|
||||
public ManagedLocalGrpcServer(final LocalAddress localAddress,
|
||||
final DefaultEventLoopGroup eventLoopGroup) {
|
||||
|
||||
final ServerBuilder<?> serverBuilder = NettyServerBuilder.forAddress(localAddress)
|
||||
.channelType(LocalServerChannel.class)
|
||||
.bossEventLoopGroup(eventLoopGroup)
|
||||
.workerEventLoopGroup(eventLoopGroup);
|
||||
|
||||
configureServer(serverBuilder);
|
||||
|
||||
server = serverBuilder.build();
|
||||
}
|
||||
|
||||
protected abstract void configureServer(final ServerBuilder<?> serverBuilder);
|
||||
|
||||
@Override
|
||||
public void start() throws IOException {
|
||||
server.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() {
|
||||
try {
|
||||
server.shutdown().awaitTermination(5, TimeUnit.MINUTES);
|
||||
} catch (final InterruptedException e) {
|
||||
server.shutdownNow();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||
|
||||
/**
|
||||
* Indicates that some problem occurred while processing an encrypted noise message (e.g. an unexpected message size/
|
||||
* format or a general encryption error).
|
||||
*/
|
||||
public class NoiseException extends NoStackTraceException {
|
||||
public NoiseException(final String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
@@ -1,119 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherState;
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.concurrent.PromiseCombiner;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* A bidirectional {@link io.netty.channel.ChannelHandler} that decrypts inbound messages, and encrypts outbound
|
||||
* messages
|
||||
*/
|
||||
public class NoiseHandler extends ChannelDuplexHandler {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseHandler.class);
|
||||
private final CipherStatePair cipherStatePair;
|
||||
|
||||
NoiseHandler(CipherStatePair cipherStatePair) {
|
||||
this.cipherStatePair = cipherStatePair;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (message instanceof ByteBuf frame) {
|
||||
if (frame.readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseException("Invalid noise message length " + frame.readableBytes());
|
||||
}
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer.
|
||||
handleInboundDataMessage(context, ByteBufUtil.getBytes(frame));
|
||||
} else {
|
||||
// Anything except ByteBufs should have been filtered out of the pipeline by now; treat this as an error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private void handleInboundDataMessage(final ChannelHandlerContext context, final byte[] frameBytes)
|
||||
throws ShortBufferException, BadPaddingException {
|
||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||
final int plaintextLength = cipherState.decryptWithAd(null,
|
||||
frameBytes, 0,
|
||||
frameBytes, 0,
|
||||
frameBytes.length);
|
||||
|
||||
// Forward the decrypted plaintext along
|
||||
context.fireChannelRead(Unpooled.wrappedBuffer(frameBytes, 0, plaintextLength));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||
throws Exception {
|
||||
if (message instanceof ByteBuf byteBuf) {
|
||||
try {
|
||||
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
|
||||
final CipherState cipherState = cipherStatePair.getSender();
|
||||
|
||||
// Server message might not fit in a single noise packet, break it up into as many chunks as we need
|
||||
final PromiseCombiner pc = new PromiseCombiner(context.executor());
|
||||
while (byteBuf.isReadable()) {
|
||||
final ByteBuf plaintext = byteBuf.readSlice(Math.min(
|
||||
// need room for a 16-byte AEAD tag
|
||||
Noise.MAX_PACKET_LEN - 16,
|
||||
byteBuf.readableBytes()));
|
||||
|
||||
final int plaintextLength = plaintext.readableBytes();
|
||||
|
||||
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
|
||||
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
|
||||
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
|
||||
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
|
||||
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
|
||||
|
||||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||
|
||||
pc.add(context.write(Unpooled.wrappedBuffer(noiseBuffer)));
|
||||
}
|
||||
pc.finish(promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(byteBuf);
|
||||
}
|
||||
} else {
|
||||
if (!(message instanceof OutboundCloseErrorMessage)) {
|
||||
// Downstream handlers may write OutboundCloseErrorMessages that don't need to be encrypted (e.g. "close" frames
|
||||
// that get issued in response to exceptions)
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(ChannelHandlerContext var1) {
|
||||
if (cipherStatePair != null) {
|
||||
cipherStatePair.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import org.whispersystems.textsecuregcm.util.NoStackTraceException;
|
||||
|
||||
/**
|
||||
* Indicates that some problem occurred while completing a Noise handshake (e.g. an unexpected message size/format or
|
||||
* a general encryption error).
|
||||
*/
|
||||
public class NoiseHandshakeException extends NoStackTraceException {
|
||||
|
||||
public NoiseHandshakeException(final String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
@@ -1,197 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufInputStream;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.DeviceIdUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
/**
|
||||
* Handles the responder side of a noise handshake and then replaces itself with a {@link NoiseHandler} which will
|
||||
* encrypt/decrypt subsequent data frames
|
||||
* <p>
|
||||
* The handler expects to receive a single inbound message, a {@link NoiseHandshakeInit} that includes the initiator
|
||||
* handshake message, connection metadata, and the type of handshake determined by the framing layer. This handler
|
||||
* currently supports two types of handshakes.
|
||||
* <p>
|
||||
* The first are IK handshakes where the initiator's static public key is authenticated by the responder. The initiator
|
||||
* handshake message must contain the ACI and deviceId of the initiator. To be authenticated, the static key provided in
|
||||
* the handshake message must match the server's stored public key for the device identified by the provided ACI and
|
||||
* deviceId.
|
||||
* <p>
|
||||
* The second are NK handshakes which are anonymous.
|
||||
* <p>
|
||||
* Optionally, the initiator can also include an initial request in their payload. If provided, this allows the server
|
||||
* to begin processing the request without an initial message delay (fast open).
|
||||
* <p>
|
||||
* Once the handshake has been validated, a {@link NoiseIdentityDeterminedEvent} will be fired. For an IK handshake,
|
||||
* this will include the {@link org.whispersystems.textsecuregcm.auth.AuthenticatedDevice} of the initiator. This
|
||||
* handler will then replace itself with a {@link NoiseHandler} with a noise state pair ready to encrypt/decrypt data
|
||||
* frames.
|
||||
*/
|
||||
public class NoiseHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private static final byte[] HANDSHAKE_WRONG_PK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY)
|
||||
.build().toByteArray();
|
||||
private static final byte[] HANDSHAKE_OK = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
|
||||
.build().toByteArray();
|
||||
|
||||
// We might get additional messages while we're waiting to process a handshake, so keep track of where we are
|
||||
private boolean receivedHandshakeInit = false;
|
||||
|
||||
private final ClientPublicKeysManager clientPublicKeysManager;
|
||||
private final ECKeyPair ecKeyPair;
|
||||
|
||||
public NoiseHandshakeHandler(final ClientPublicKeysManager clientPublicKeysManager, final ECKeyPair ecKeyPair) {
|
||||
this.clientPublicKeysManager = clientPublicKeysManager;
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (!(message instanceof NoiseHandshakeInit handshakeInit)) {
|
||||
// Anything except HandshakeInit should have been filtered out of the pipeline by now; treat this as an error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
if (receivedHandshakeInit) {
|
||||
throw new NoiseHandshakeException("Should not receive messages until handshake complete");
|
||||
}
|
||||
receivedHandshakeInit = true;
|
||||
|
||||
if (handshakeInit.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseHandshakeException("Invalid noise message length " + handshakeInit.content().readableBytes());
|
||||
}
|
||||
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer
|
||||
handleInboundHandshake(context,
|
||||
handshakeInit.getRemoteAddress(),
|
||||
handshakeInit.getHandshakePattern(),
|
||||
ByteBufUtil.getBytes(handshakeInit.content()));
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInboundHandshake(
|
||||
final ChannelHandlerContext context,
|
||||
final InetAddress remoteAddress,
|
||||
final HandshakePattern handshakePattern,
|
||||
final byte[] frameBytes) throws NoiseHandshakeException {
|
||||
final NoiseHandshakeHelper handshakeHelper = new NoiseHandshakeHelper(handshakePattern, ecKeyPair);
|
||||
final ByteBuf payload = handshakeHelper.read(frameBytes);
|
||||
|
||||
// Parse the handshake message
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit;
|
||||
try {
|
||||
handshakeInit = NoiseTunnelProtos.HandshakeInit.parseFrom(new ByteBufInputStream(payload));
|
||||
} catch (IOException e) {
|
||||
throw new NoiseHandshakeException("Failed to parse handshake message");
|
||||
}
|
||||
|
||||
switch (handshakePattern) {
|
||||
case NK -> {
|
||||
if (handshakeInit.getDeviceId() != 0 || !handshakeInit.getAci().isEmpty()) {
|
||||
throw new NoiseHandshakeException("Anonymous handshake should not include identifiers");
|
||||
}
|
||||
handleAuthenticated(context, handshakeHelper, remoteAddress, handshakeInit, Optional.empty());
|
||||
}
|
||||
case IK -> {
|
||||
final byte[] publicKeyFromClient = handshakeHelper.remotePublicKey()
|
||||
.orElseThrow(() -> new IllegalStateException("No remote public key"));
|
||||
final UUID accountIdentifier = aci(handshakeInit);
|
||||
final byte deviceId = deviceId(handshakeInit);
|
||||
clientPublicKeysManager
|
||||
.findPublicKey(accountIdentifier, deviceId)
|
||||
.whenCompleteAsync((storedPublicKey, throwable) -> {
|
||||
if (throwable != null) {
|
||||
context.fireExceptionCaught(ExceptionUtils.unwrap(throwable));
|
||||
return;
|
||||
}
|
||||
final boolean valid = storedPublicKey
|
||||
.map(spk -> MessageDigest.isEqual(publicKeyFromClient, spk.getPublicKeyBytes()))
|
||||
.orElse(false);
|
||||
if (!valid) {
|
||||
// Write a handshake response indicating that the client used the wrong public key
|
||||
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_WRONG_PK);
|
||||
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
|
||||
context.fireExceptionCaught(new NoiseHandshakeException("Bad public key"));
|
||||
return;
|
||||
}
|
||||
handleAuthenticated(context,
|
||||
handshakeHelper, remoteAddress, handshakeInit,
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)));
|
||||
}, context.executor());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void handleAuthenticated(final ChannelHandlerContext context,
|
||||
final NoiseHandshakeHelper handshakeHelper,
|
||||
final InetAddress remoteAddress,
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit,
|
||||
final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||
context.fireUserEventTriggered(new NoiseIdentityDeterminedEvent(
|
||||
maybeAuthenticatedDevice,
|
||||
remoteAddress,
|
||||
handshakeInit.getUserAgent(),
|
||||
handshakeInit.getAcceptLanguage()));
|
||||
|
||||
// Now that we've authenticated, write the handshake response
|
||||
final byte[] handshakeMessage = handshakeHelper.write(HANDSHAKE_OK);
|
||||
context.writeAndFlush(Unpooled.wrappedBuffer(handshakeMessage))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
|
||||
// The handshake is complete. We can start intercepting read/write for noise encryption/decryption
|
||||
// Note: It may be tempting to swap the before/remove for a replace, but then when we forward the fast open
|
||||
// request it will go through the NoiseHandler. We want to skip the NoiseHandler because we've already
|
||||
// decrypted the fastOpen request
|
||||
context.pipeline()
|
||||
.addBefore(context.name(), null, new NoiseHandler(handshakeHelper.getHandshakeState().split()));
|
||||
context.pipeline().remove(NoiseHandshakeHandler.class);
|
||||
if (!handshakeInit.getFastOpenRequest().isEmpty()) {
|
||||
// The handshake had a fast-open request. Forward the plaintext of the request to the server, we'll
|
||||
// encrypt the response when the server writes back through us
|
||||
context.fireChannelRead(Unpooled.wrappedBuffer(handshakeInit.getFastOpenRequest().asReadOnlyByteBuffer()));
|
||||
}
|
||||
}
|
||||
|
||||
private static UUID aci(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||
try {
|
||||
return UUIDUtil.fromByteString(handshakePayload.getAci());
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new NoiseHandshakeException("Could not parse aci");
|
||||
}
|
||||
}
|
||||
|
||||
private static byte deviceId(final NoiseTunnelProtos.HandshakeInit handshakePayload) throws NoiseHandshakeException {
|
||||
if (!DeviceIdUtil.isValid(handshakePayload.getDeviceId())) {
|
||||
throw new NoiseHandshakeException("Invalid deviceId");
|
||||
}
|
||||
return (byte) handshakePayload.getDeviceId();
|
||||
}
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.Optional;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
|
||||
/**
|
||||
* Helper for the responder of a 2-message handshake with a pre-shared responder static key
|
||||
*/
|
||||
class NoiseHandshakeHelper {
|
||||
|
||||
private final static int AEAD_TAG_LENGTH = 16;
|
||||
private final static int KEY_LENGTH = 32;
|
||||
|
||||
private final HandshakePattern handshakePattern;
|
||||
private final ECKeyPair serverStaticKeyPair;
|
||||
private final HandshakeState handshakeState;
|
||||
|
||||
NoiseHandshakeHelper(HandshakePattern handshakePattern, ECKeyPair serverStaticKeyPair) {
|
||||
this.handshakePattern = handshakePattern;
|
||||
this.serverStaticKeyPair = serverStaticKeyPair;
|
||||
try {
|
||||
this.handshakeState = new HandshakeState(handshakePattern.protocol(), HandshakeState.RESPONDER);
|
||||
} catch (final NoSuchAlgorithmException e) {
|
||||
throw new AssertionError("Unsupported Noise algorithm: " + handshakePattern.protocol(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the length of the initiator's keys
|
||||
*
|
||||
* @return length of the handshake message sent by the remote party (the initiator) not including the payload
|
||||
*/
|
||||
private int initiatorHandshakeMessageKeyLength() {
|
||||
return switch (handshakePattern) {
|
||||
// ephemeral key, static key (encrypted), AEAD tag for static key
|
||||
case IK -> KEY_LENGTH + KEY_LENGTH + AEAD_TAG_LENGTH;
|
||||
// ephemeral key only
|
||||
case NK -> KEY_LENGTH;
|
||||
};
|
||||
}
|
||||
|
||||
HandshakeState getHandshakeState() {
|
||||
return this.handshakeState;
|
||||
}
|
||||
|
||||
ByteBuf read(byte[] remoteHandshakeMessage) throws NoiseHandshakeException {
|
||||
if (handshakeState.getAction() != HandshakeState.NO_ACTION) {
|
||||
throw new NoiseHandshakeException("Cannot send more data before handshake is complete");
|
||||
}
|
||||
|
||||
// Length for an empty payload
|
||||
final int minMessageLength = initiatorHandshakeMessageKeyLength() + AEAD_TAG_LENGTH;
|
||||
if (remoteHandshakeMessage.length < minMessageLength || remoteHandshakeMessage.length > Noise.MAX_PACKET_LEN) {
|
||||
throw new NoiseHandshakeException("Unexpected ephemeral key message length");
|
||||
}
|
||||
|
||||
final int payloadLength = remoteHandshakeMessage.length - initiatorHandshakeMessageKeyLength() - AEAD_TAG_LENGTH;
|
||||
|
||||
// Cryptographically initializing a handshake is expensive, and so we defer it until we're confident the client is
|
||||
// making a good-faith effort to perform a handshake (i.e. now). Noise-java in particular will derive a public key
|
||||
// from the supplied private key (and will in fact overwrite any previously-set public key when setting a private
|
||||
// key), so we just set the private key here.
|
||||
handshakeState.getLocalKeyPair().setPrivateKey(serverStaticKeyPair.getPrivateKey().serialize(), 0);
|
||||
handshakeState.start();
|
||||
|
||||
int payloadBytesRead;
|
||||
|
||||
try {
|
||||
payloadBytesRead = handshakeState.readMessage(remoteHandshakeMessage, 0, remoteHandshakeMessage.length,
|
||||
remoteHandshakeMessage, 0);
|
||||
} catch (final ShortBufferException e) {
|
||||
// This should never happen since we're checking the length of the frame up front
|
||||
throw new NoiseHandshakeException("Unexpected client payload");
|
||||
} catch (final BadPaddingException e) {
|
||||
// We aren't using padding but may get this error if the AEAD tag does not match the encrypted client static key
|
||||
// or payload
|
||||
throw new NoiseHandshakeException("Invalid keys or payload");
|
||||
}
|
||||
if (payloadBytesRead != payloadLength) {
|
||||
throw new NoiseHandshakeException(
|
||||
"Unexpected payload length, required " + payloadLength + " but got " + payloadBytesRead);
|
||||
}
|
||||
return Unpooled.wrappedBuffer(remoteHandshakeMessage, 0, payloadBytesRead);
|
||||
}
|
||||
|
||||
byte[] write(byte[] payload) {
|
||||
if (handshakeState.getAction() != HandshakeState.WRITE_MESSAGE) {
|
||||
throw new IllegalStateException("Cannot send data before handshake is complete");
|
||||
}
|
||||
|
||||
// Currently only support handshake patterns where the server static key is known
|
||||
// Send our ephemeral key and the response to the initiator with the encrypted payload
|
||||
final byte[] response = new byte[KEY_LENGTH + payload.length + AEAD_TAG_LENGTH];
|
||||
try {
|
||||
int written = handshakeState.writeMessage(response, 0, payload, 0, payload.length);
|
||||
if (written != response.length) {
|
||||
throw new IllegalStateException("Unexpected handshake response length");
|
||||
}
|
||||
return response;
|
||||
} catch (final ShortBufferException e) {
|
||||
// This should never happen for messages of known length that we control
|
||||
throw new IllegalStateException("Key material buffer was too short for message", e);
|
||||
}
|
||||
}
|
||||
|
||||
Optional<byte[]> remotePublicKey() {
|
||||
return Optional.ofNullable(handshakeState.getRemotePublicKey()).map(dhstate -> {
|
||||
final byte[] publicKeyFromClient = new byte[handshakeState.getRemotePublicKey().getPublicKeyLength()];
|
||||
handshakeState.getRemotePublicKey().getPublicKey(publicKeyFromClient, 0);
|
||||
return publicKeyFromClient;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.DefaultByteBufHolder;
|
||||
import java.net.InetAddress;
|
||||
|
||||
/**
|
||||
* A message that includes the initiator's handshake message, connection metadata, and the handshake type. The metadata
|
||||
* and handshake type are extracted from the framing layer, so this allows receivers to be framing layer agnostic.
|
||||
*/
|
||||
public class NoiseHandshakeInit extends DefaultByteBufHolder {
|
||||
|
||||
private final InetAddress remoteAddress;
|
||||
private final HandshakePattern handshakePattern;
|
||||
|
||||
public NoiseHandshakeInit(
|
||||
final InetAddress remoteAddress,
|
||||
final HandshakePattern handshakePattern,
|
||||
final ByteBuf initiatorHandshakeMessage) {
|
||||
super(initiatorHandshakeMessage);
|
||||
this.remoteAddress = remoteAddress;
|
||||
this.handshakePattern = handshakePattern;
|
||||
}
|
||||
|
||||
public InetAddress getRemoteAddress() {
|
||||
return remoteAddress;
|
||||
}
|
||||
|
||||
public HandshakePattern getHandshakePattern() {
|
||||
return handshakePattern;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import java.net.InetAddress;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
|
||||
/**
|
||||
* An event that indicates that an identity of a noise handshake initiator has been determined. If the initiator is
|
||||
* connecting anonymously, the identity is empty, otherwise it will be present and already authenticated.
|
||||
*
|
||||
* @param authenticatedDevice the device authenticated as part of the handshake, or empty if the handshake was not of a
|
||||
* type that performs authentication
|
||||
* @param remoteAddress the remote address of the connecting client
|
||||
* @param userAgent the client supplied userAgent
|
||||
* @param acceptLanguage the client supplied acceptLanguage
|
||||
*/
|
||||
public record NoiseIdentityDeterminedEvent(
|
||||
Optional<AuthenticatedDevice> authenticatedDevice,
|
||||
InetAddress remoteAddress,
|
||||
String userAgent,
|
||||
String acceptLanguage) {
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
/**
|
||||
* An error written to the outbound pipeline that indicates the connection should be closed
|
||||
*/
|
||||
public record OutboundCloseErrorMessage(Code code, String message) {
|
||||
public enum Code {
|
||||
|
||||
/**
|
||||
* The server decided to close the connection. This could be because the server is going away, or it could be
|
||||
* because the credentials for the connected client have been updated.
|
||||
*/
|
||||
SERVER_CLOSED,
|
||||
|
||||
/**
|
||||
* There was a noise decryption error after the noise session was established
|
||||
*/
|
||||
NOISE_ERROR,
|
||||
|
||||
/**
|
||||
* There was an error establishing the noise handshake
|
||||
*/
|
||||
NOISE_HANDSHAKE_ERROR,
|
||||
|
||||
INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
|
||||
/**
|
||||
* A proxy handler writes all data read from one channel to another peer channel.
|
||||
*/
|
||||
public class ProxyHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final Channel peerChannel;
|
||||
|
||||
public ProxyHandler(final Channel peerChannel) {
|
||||
this.peerChannel = peerChannel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||
peerChannel.writeAndFlush(message)
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
}
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
|
||||
|
||||
/**
|
||||
* A proxy protocol detection handler watches for HAProxy PROXY protocol messages at the beginning of a TCP connection.
|
||||
* If a connection begins with a proxy message, this handler will add a {@link HAProxyMessageDecoder} to the pipeline.
|
||||
* In all cases, once this handler has determined that a connection does or does not begin with a proxy protocol
|
||||
* message, it will remove itself from the pipeline and pass any intercepted down the pipeline.
|
||||
*
|
||||
* @see <a href="https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt">The PROXY protocol</a>
|
||||
*/
|
||||
public class ProxyProtocolDetectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private CompositeByteBuf accumulator;
|
||||
|
||||
@VisibleForTesting
|
||||
static final int PROXY_MESSAGE_DETECTION_BYTES = 12;
|
||||
|
||||
@Override
|
||||
public void handlerAdded(final ChannelHandlerContext context) {
|
||||
// We need at least 12 bytes to decide if a byte buffer contains a proxy protocol message. Assuming we only get
|
||||
// non-empty buffers, that means we'll need at most 12 sub-buffers to have a complete message. In virtually every
|
||||
// practical case, though, we'll be able to tell from the first packet.
|
||||
accumulator = new CompositeByteBuf(context.alloc(), false, PROXY_MESSAGE_DETECTION_BYTES);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof ByteBuf byteBuf) {
|
||||
accumulator.addComponent(true, byteBuf);
|
||||
|
||||
switch (HAProxyMessageDecoder.detectProtocol(accumulator).state()) {
|
||||
case NEEDS_MORE_DATA -> {
|
||||
}
|
||||
|
||||
case INVALID -> {
|
||||
// We have enough information to determine that this connection is NOT starting with a proxy protocol message,
|
||||
// and we can just pass the accumulated bytes through
|
||||
context.fireChannelRead(accumulator);
|
||||
|
||||
accumulator = null;
|
||||
context.pipeline().remove(this);
|
||||
}
|
||||
|
||||
case DETECTED -> {
|
||||
// We have enough information to know that we're dealing with a proxy protocol message; add appropriate
|
||||
// handlers and pass the accumulated bytes through
|
||||
context.pipeline().addAfter(context.name(), null, new HAProxyMessageDecoder());
|
||||
context.fireChannelRead(accumulator);
|
||||
|
||||
accumulator = null;
|
||||
context.pipeline().remove(this);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
super.channelRead(context, message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
if (accumulator != null) {
|
||||
accumulator.release();
|
||||
accumulator = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseException;
|
||||
|
||||
/**
|
||||
* In the inbound direction, this handler strips the NoiseDirectFrame wrapper we read off the wire and then forwards the
|
||||
* noise packet to the noise layer as a {@link ByteBuf} for decryption.
|
||||
* <p>
|
||||
* In the outbound direction, this handler wraps encrypted noise packet {@link ByteBuf}s in a NoiseDirectFrame wrapper
|
||||
* so it can be wire serialized. This handler assumes the first outbound message received will correspond to the
|
||||
* handshake response, and then the subsequent messages are all data frame payloads.
|
||||
*/
|
||||
public class NoiseDirectDataFrameCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame frame) {
|
||||
if (frame.frameType() != NoiseDirectFrame.FrameType.DATA) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw new NoiseException("Invalid frame type received (expected DATA): " + frame.frameType());
|
||||
}
|
||||
ctx.fireChannelRead(frame.content());
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
ctx.write(new NoiseDirectFrame(NoiseDirectFrame.FrameType.DATA, bb), promise);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.DefaultByteBufHolder;
|
||||
|
||||
public class NoiseDirectFrame extends DefaultByteBufHolder {
|
||||
|
||||
static final byte VERSION = 0x00;
|
||||
|
||||
private final FrameType frameType;
|
||||
|
||||
public NoiseDirectFrame(final FrameType frameType, final ByteBuf data) {
|
||||
super(data);
|
||||
this.frameType = frameType;
|
||||
}
|
||||
|
||||
public FrameType frameType() {
|
||||
return frameType;
|
||||
}
|
||||
|
||||
public byte versionedFrameTypeByte() {
|
||||
final byte frameBits = frameType().getFrameBits();
|
||||
return (byte) ((NoiseDirectFrame.VERSION << 4) | frameBits);
|
||||
}
|
||||
|
||||
|
||||
public enum FrameType {
|
||||
/**
|
||||
* The payload is the initiator message for a Noise NK handshake. If established, the
|
||||
* session will be unauthenticated.
|
||||
*/
|
||||
NK_HANDSHAKE((byte) 1),
|
||||
/**
|
||||
* The payload is the initiator message for a Noise IK handshake. If established, the
|
||||
* session will be authenticated.
|
||||
*/
|
||||
IK_HANDSHAKE((byte) 2),
|
||||
/**
|
||||
* The payload is an encrypted noise packet.
|
||||
*/
|
||||
DATA((byte) 3),
|
||||
/**
|
||||
* A frame sent before the connection is closed. The payload is a protobuf indicating why the connection is being
|
||||
* closed.
|
||||
*/
|
||||
CLOSE((byte) 4);
|
||||
|
||||
private final byte frameType;
|
||||
|
||||
FrameType(byte frameType) {
|
||||
if (frameType != (0x0F & frameType)) {
|
||||
throw new IllegalStateException("Frame type must fit in 4 bits");
|
||||
}
|
||||
this.frameType = frameType;
|
||||
}
|
||||
|
||||
public byte getFrameBits() {
|
||||
return frameType;
|
||||
}
|
||||
|
||||
public boolean isHandshake() {
|
||||
return switch (this) {
|
||||
case IK_HANDSHAKE, NK_HANDSHAKE -> true;
|
||||
case DATA, CLOSE -> false;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
|
||||
/**
|
||||
* Handles conversion between bytes on the wire and {@link NoiseDirectFrame}s. This handler assumes that inbound bytes
|
||||
* have already been framed using a {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
|
||||
*/
|
||||
public class NoiseDirectFrameCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof ByteBuf byteBuf) {
|
||||
try {
|
||||
ctx.fireChannelRead(deserialize(byteBuf));
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(byteBuf);
|
||||
throw e;
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof NoiseDirectFrame noiseDirectFrame) {
|
||||
try {
|
||||
// Serialize the frame into a newly allocated direct buffer. Since this is the last handler before the
|
||||
// network, nothing should have to make another copy of this. If later another layer is added, it may be more
|
||||
// efficient to reuse the input buffer (typically not direct) by using a composite byte buffer
|
||||
final ByteBuf serialized = serialize(ctx, noiseDirectFrame);
|
||||
ctx.writeAndFlush(serialized, promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(noiseDirectFrame);
|
||||
}
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
|
||||
private ByteBuf serialize(
|
||||
final ChannelHandlerContext ctx,
|
||||
final NoiseDirectFrame noiseDirectFrame) {
|
||||
if (noiseDirectFrame.content().readableBytes() > Noise.MAX_PACKET_LEN) {
|
||||
throw new IllegalStateException("Payload too long: " + noiseDirectFrame.content().readableBytes());
|
||||
}
|
||||
|
||||
// 1 version/frametype byte, 2 length bytes, content
|
||||
final ByteBuf byteBuf = ctx.alloc().buffer(1 + 2 + noiseDirectFrame.content().readableBytes());
|
||||
|
||||
byteBuf.writeByte(noiseDirectFrame.versionedFrameTypeByte());
|
||||
byteBuf.writeShort(noiseDirectFrame.content().readableBytes());
|
||||
byteBuf.writeBytes(noiseDirectFrame.content());
|
||||
return byteBuf;
|
||||
}
|
||||
|
||||
private NoiseDirectFrame deserialize(final ByteBuf byteBuf) throws Exception {
|
||||
final byte versionAndFrameByte = byteBuf.readByte();
|
||||
final int version = (versionAndFrameByte & 0xF0) >> 4;
|
||||
if (version != NoiseDirectFrame.VERSION) {
|
||||
throw new NoiseHandshakeException("Invalid NoiseDirect version: " + version);
|
||||
}
|
||||
final byte frameTypeBits = (byte) (versionAndFrameByte & 0x0F);
|
||||
final NoiseDirectFrame.FrameType frameType = switch (frameTypeBits) {
|
||||
case 1 -> NoiseDirectFrame.FrameType.NK_HANDSHAKE;
|
||||
case 2 -> NoiseDirectFrame.FrameType.IK_HANDSHAKE;
|
||||
case 3 -> NoiseDirectFrame.FrameType.DATA;
|
||||
case 4 -> NoiseDirectFrame.FrameType.CLOSE;
|
||||
default -> throw new NoiseHandshakeException("Invalid NoiseDirect frame type: " + frameTypeBits);
|
||||
};
|
||||
|
||||
final int length = Short.toUnsignedInt(byteBuf.readShort());
|
||||
if (length != byteBuf.readableBytes()) {
|
||||
throw new IllegalArgumentException(
|
||||
"Payload length did not match remaining buffer, should have been guaranteed by a previous handler");
|
||||
}
|
||||
return new NoiseDirectFrame(frameType, byteBuf.readSlice(length));
|
||||
}
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
|
||||
/**
|
||||
* Waits for a Handshake {@link NoiseDirectFrame} and then replaces itself with a {@link NoiseDirectDataFrameCodec} and
|
||||
* forwards the handshake frame along as a {@link NoiseHandshakeInit} message
|
||||
*/
|
||||
public class NoiseDirectHandshakeSelector extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame frame) {
|
||||
try {
|
||||
if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress)) {
|
||||
throw new IOException("Could not determine remote address");
|
||||
}
|
||||
// We've received an inbound handshake frame. Pull the framing-protocol specific data the downstream handler
|
||||
// needs into a NoiseHandshakeInit message and forward that along
|
||||
final NoiseHandshakeInit handshakeMessage = new NoiseHandshakeInit(inetSocketAddress.getAddress(),
|
||||
switch (frame.frameType()) {
|
||||
case DATA -> throw new NoiseHandshakeException("First message must have handshake frame type");
|
||||
case CLOSE -> throw new IllegalStateException("Close frames should not reach handshake selector");
|
||||
case IK_HANDSHAKE -> HandshakePattern.IK;
|
||||
case NK_HANDSHAKE -> HandshakePattern.NK;
|
||||
}, frame.content());
|
||||
|
||||
// Subsequent inbound messages and outbound should be data type frames or close frames. Inbound data frames
|
||||
// should be unwrapped and forwarded to the noise handler, outbound buffers should be wrapped and forwarded
|
||||
// for network serialization. Note that we need to install the Data frame handler before firing the read,
|
||||
// because we may receive an outbound message from the noiseHandler
|
||||
ctx.pipeline().replace(ctx.name(), null, new NoiseDirectDataFrameCodec());
|
||||
ctx.fireChannelRead(handshakeMessage);
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw e;
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
||||
|
||||
/**
|
||||
* Watches for inbound close frames and closes the connection in response
|
||||
*/
|
||||
public class NoiseDirectInboundCloseHandler extends ChannelInboundHandlerAdapter {
|
||||
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "clientClose");
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||
try {
|
||||
final NoiseDirectProtos.CloseReason closeReason = NoiseDirectProtos.CloseReason
|
||||
.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
|
||||
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "reason", closeReason.getCode().name()).increment();
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufOutputStream;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
||||
/**
|
||||
* Translates {@link OutboundCloseErrorMessage}s into {@link NoiseDirectFrame} error frames. After error frames are
|
||||
* written, the channel is closed
|
||||
*/
|
||||
class NoiseDirectOutboundErrorHandler extends ChannelOutboundHandlerAdapter {
|
||||
private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(NoiseDirectInboundCloseHandler.class, "serverClose");
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||
final NoiseDirectProtos.CloseReason.Code code = switch (err.code()) {
|
||||
case SERVER_CLOSED -> NoiseDirectProtos.CloseReason.Code.UNAVAILABLE;
|
||||
case NOISE_ERROR -> NoiseDirectProtos.CloseReason.Code.ENCRYPTION_ERROR;
|
||||
case NOISE_HANDSHAKE_ERROR -> NoiseDirectProtos.CloseReason.Code.HANDSHAKE_ERROR;
|
||||
case INTERNAL_SERVER_ERROR -> NoiseDirectProtos.CloseReason.Code.INTERNAL_ERROR;
|
||||
};
|
||||
Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "reason", code.name()).increment();
|
||||
|
||||
final NoiseDirectProtos.CloseReason proto = NoiseDirectProtos.CloseReason.newBuilder()
|
||||
.setCode(code)
|
||||
.setMessage(err.message())
|
||||
.build();
|
||||
final ByteBuf byteBuf = ctx.alloc().buffer(proto.getSerializedSize());
|
||||
proto.writeTo(new ByteBufOutputStream(byteBuf));
|
||||
ctx.writeAndFlush(new NoiseDirectFrame(NoiseDirectFrame.FrameType.CLOSE, byteBuf))
|
||||
.addListener(ChannelFutureListener.CLOSE);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.channel.socket.ServerSocketChannel;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ErrorHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.EstablishLocalGrpcConnectionHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.FramingType;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HAProxyMessageHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeHandler;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ProxyProtocolDetectionHandler;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
* A NoiseDirectTunnelServer accepts traffic from the public internet (in the form of Noise packets framed by a custom
|
||||
* binary framing protocol) and passes it through to a local gRPC server.
|
||||
*/
|
||||
public class NoiseDirectTunnelServer implements Managed {
|
||||
|
||||
private final ServerBootstrap bootstrap;
|
||||
private ServerSocketChannel channel;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseDirectTunnelServer.class);
|
||||
|
||||
public NoiseDirectTunnelServer(final int port,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress) {
|
||||
|
||||
this.bootstrap = new ServerBootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(NioServerSocketChannel.class)
|
||||
.localAddress(port)
|
||||
.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(SocketChannel socketChannel) {
|
||||
socketChannel.pipeline()
|
||||
.addLast(new ProxyProtocolDetectionHandler())
|
||||
.addLast(new HAProxyMessageHandler())
|
||||
// frame byte followed by a 2-byte length field
|
||||
.addLast(new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2))
|
||||
// Parses NoiseDirectFrames from wire bytes and vice versa
|
||||
.addLast(new NoiseDirectFrameCodec())
|
||||
// Terminate the connection if the client sends us a close frame
|
||||
.addLast(new NoiseDirectInboundCloseHandler())
|
||||
// Turn generic OutboundCloseErrorMessages into noise direct error frames
|
||||
.addLast(new NoiseDirectOutboundErrorHandler())
|
||||
// Forwards the first payload supplemented with handshake metadata, and then replaces itself with a
|
||||
// NoiseDirectDataFrameCodec to handle subsequent data frames
|
||||
.addLast(new NoiseDirectHandshakeSelector())
|
||||
// Performs the noise handshake and then replace itself with a NoiseHandler
|
||||
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(
|
||||
grpcClientConnectionManager,
|
||||
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
|
||||
FramingType.NOISE_DIRECT))
|
||||
.addLast(new ErrorHandler());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public InetSocketAddress getLocalAddress() {
|
||||
return channel.localAddress();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws InterruptedException {
|
||||
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() throws InterruptedException {
|
||||
if (channel != null) {
|
||||
channel.close().await();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
|
||||
enum ApplicationWebSocketCloseReason {
|
||||
NOISE_HANDSHAKE_ERROR(4001),
|
||||
NOISE_ENCRYPTION_ERROR(4002);
|
||||
|
||||
private final int statusCode;
|
||||
|
||||
ApplicationWebSocketCloseReason(final int statusCode) {
|
||||
this.statusCode = statusCode;
|
||||
}
|
||||
|
||||
public int getStatusCode() {
|
||||
return statusCode;
|
||||
}
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.channel.socket.ServerSocketChannel;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
import io.netty.handler.codec.http.HttpServerCodec;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import io.netty.handler.ssl.ClientAuth;
|
||||
import io.netty.handler.ssl.OpenSsl;
|
||||
import io.netty.handler.ssl.SslContext;
|
||||
import io.netty.handler.ssl.SslContextBuilder;
|
||||
import io.netty.handler.ssl.SslProtocols;
|
||||
import io.netty.handler.ssl.SslProvider;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.security.PrivateKey;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.concurrent.Executor;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.net.ssl.SSLException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.*;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
/**
|
||||
* A Noise-over-WebSocket tunnel server accepts traffic from the public internet (in the form of Noise packets framed by
|
||||
* binary WebSocket frames) and passes it through to a local gRPC server.
|
||||
*/
|
||||
public class NoiseWebSocketTunnelServer implements Managed {
|
||||
|
||||
private final ServerBootstrap bootstrap;
|
||||
private ServerSocketChannel channel;
|
||||
|
||||
static final String AUTHENTICATED_SERVICE_PATH = "/authenticated";
|
||||
static final String ANONYMOUS_SERVICE_PATH = "/anonymous";
|
||||
static final String HEALTH_CHECK_PATH = "/health-check";
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseWebSocketTunnelServer.class);
|
||||
|
||||
public NoiseWebSocketTunnelServer(final int websocketPort,
|
||||
@Nullable final X509Certificate[] tlsCertificateChain,
|
||||
@Nullable final PrivateKey tlsPrivateKey,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final Executor delegatedTaskExecutor,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair ecKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final String recognizedProxySecret) throws SSLException {
|
||||
|
||||
@Nullable final SslContext sslContext;
|
||||
|
||||
if (tlsCertificateChain != null && tlsPrivateKey != null) {
|
||||
final SslProvider sslProvider;
|
||||
|
||||
if (OpenSsl.isAvailable()) {
|
||||
log.info("Native OpenSSL provider is available; will use native provider");
|
||||
sslProvider = SslProvider.OPENSSL;
|
||||
} else {
|
||||
log.info("No native SSL provider available; will use JDK provider");
|
||||
sslProvider = SslProvider.JDK;
|
||||
}
|
||||
|
||||
sslContext = SslContextBuilder.forServer(tlsPrivateKey, tlsCertificateChain)
|
||||
.clientAuth(ClientAuth.NONE)
|
||||
// Some load balancers require TLS 1.2 for health checks
|
||||
.protocols(SslProtocols.TLS_v1_3, SslProtocols.TLS_v1_2)
|
||||
.sslProvider(sslProvider)
|
||||
.build();
|
||||
} else {
|
||||
log.warn("No TLS credentials provided; Noise-over-WebSocket tunnel will not use TLS. This configuration is not suitable for production environments.");
|
||||
sslContext = null;
|
||||
}
|
||||
|
||||
this.bootstrap = new ServerBootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(NioServerSocketChannel.class)
|
||||
.localAddress(websocketPort)
|
||||
.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(SocketChannel socketChannel) {
|
||||
socketChannel.pipeline()
|
||||
.addLast(new ProxyProtocolDetectionHandler())
|
||||
.addLast(new HAProxyMessageHandler());
|
||||
|
||||
if (sslContext != null) {
|
||||
socketChannel.pipeline().addLast(sslContext.newHandler(socketChannel.alloc(), delegatedTaskExecutor));
|
||||
}
|
||||
|
||||
socketChannel.pipeline()
|
||||
.addLast(new HttpServerCodec())
|
||||
.addLast(new HttpObjectAggregator(Noise.MAX_PACKET_LEN))
|
||||
// The WebSocket opening handshake handler will remove itself from the pipeline once it has received a valid WebSocket upgrade
|
||||
// request and passed it down the pipeline
|
||||
.addLast(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_SERVICE_PATH, ANONYMOUS_SERVICE_PATH, HEALTH_CHECK_PATH))
|
||||
.addLast(new WebSocketServerProtocolHandler("/", true))
|
||||
// Metrics on inbound/outbound Close frames
|
||||
.addLast(new WebSocketCloseMetricHandler())
|
||||
// Turn generic OutboundCloseErrorMessages into websocket close frames
|
||||
.addLast(new WebSocketOutboundErrorHandler())
|
||||
.addLast(new RejectUnsupportedMessagesHandler())
|
||||
.addLast(new WebsocketPayloadCodec())
|
||||
// The WebSocket handshake complete listener will forward the first payload supplemented with
|
||||
// data from the websocket handshake completion event, and then remove itself from the pipeline
|
||||
.addLast(new WebsocketHandshakeCompleteHandler(recognizedProxySecret))
|
||||
// The NoiseHandshakeHandler will perform the noise handshake and then replace itself with a
|
||||
// NoiseHandler
|
||||
.addLast(new NoiseHandshakeHandler(clientPublicKeysManager, ecKeyPair))
|
||||
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
|
||||
// once the Noise handshake has completed
|
||||
.addLast(new EstablishLocalGrpcConnectionHandler(
|
||||
grpcClientConnectionManager,
|
||||
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
|
||||
FramingType.WEBSOCKET))
|
||||
.addLast(new ErrorHandler());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
InetSocketAddress getLocalAddress() {
|
||||
return channel.localAddress();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws InterruptedException {
|
||||
channel = (ServerSocketChannel) bootstrap.bind().await().channel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() throws InterruptedException {
|
||||
if (channel != null) {
|
||||
channel.close().await();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
|
||||
/**
|
||||
* A "reject unsupported message" handler closes the channel if it receives messages it does not know how to process.
|
||||
*/
|
||||
public class RejectUnsupportedMessagesHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof final WebSocketFrame webSocketFrame) {
|
||||
if (webSocketFrame instanceof final TextWebSocketFrame textWebSocketFrame) {
|
||||
try {
|
||||
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
|
||||
} finally {
|
||||
textWebSocketFrame.release();
|
||||
}
|
||||
} else {
|
||||
// Allow all other types of WebSocket frames
|
||||
context.fireChannelRead(webSocketFrame);
|
||||
}
|
||||
} else {
|
||||
// Discard anything that's not a WebSocket frame
|
||||
ReferenceCountUtil.release(message);
|
||||
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INVALID_MESSAGE_TYPE));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
||||
public class WebSocketCloseMetricHandler extends ChannelDuplexHandler {
|
||||
|
||||
private static String CLIENT_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "clientClose");
|
||||
private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketCloseMetricHandler.class, "serverClose");
|
||||
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
|
||||
if (msg instanceof CloseWebSocketFrame closeFrame) {
|
||||
Metrics.counter(CLIENT_CLOSE_COUNTER_NAME, "closeCode", validatedCloseCode(closeFrame.statusCode())).increment();
|
||||
}
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof CloseWebSocketFrame closeFrame) {
|
||||
Metrics.counter(SERVER_CLOSE_COUNTER_NAME, "closeCode", Integer.toString(closeFrame.statusCode())).increment();
|
||||
}
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
|
||||
private static String validatedCloseCode(int closeCode) {
|
||||
|
||||
if (closeCode >= 1000 && closeCode <= 1015) {
|
||||
// RFC-6455 pre-defined status codes
|
||||
return Integer.toString(closeCode);
|
||||
} else if (closeCode >= 4000 && closeCode <= 4100) {
|
||||
// Application status codes
|
||||
return Integer.toString(closeCode);
|
||||
} else {
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpHeaderValues;
|
||||
import io.netty.handler.codec.http.HttpMethod;
|
||||
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
|
||||
/**
|
||||
* A WebSocket opening handshake handler serves as the "front door" for the WebSocket/Noise tunnel and gracefully
|
||||
* rejects requests for anything other than a WebSocket connection to a known endpoint.
|
||||
*/
|
||||
class WebSocketOpeningHandshakeHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final String authenticatedPath;
|
||||
private final String anonymousPath;
|
||||
private final String healthCheckPath;
|
||||
|
||||
WebSocketOpeningHandshakeHandler(final String authenticatedPath,
|
||||
final String anonymousPath,
|
||||
final String healthCheckPath) {
|
||||
|
||||
this.authenticatedPath = authenticatedPath;
|
||||
this.anonymousPath = anonymousPath;
|
||||
this.healthCheckPath = healthCheckPath;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof FullHttpRequest request) {
|
||||
boolean shouldReleaseRequest = true;
|
||||
|
||||
try {
|
||||
if (request.decoderResult().isSuccess()) {
|
||||
if (HttpMethod.GET.equals(request.method())) {
|
||||
if (authenticatedPath.equals(request.uri()) || anonymousPath.equals(request.uri())) {
|
||||
if (request.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) {
|
||||
// Pass the request along to the websocket handshake handler and remove ourselves from the pipeline
|
||||
shouldReleaseRequest = false;
|
||||
|
||||
context.fireChannelRead(request);
|
||||
context.pipeline().remove(this);
|
||||
} else {
|
||||
closeConnectionWithStatus(context, request, HttpResponseStatus.UPGRADE_REQUIRED);
|
||||
}
|
||||
} else if (healthCheckPath.equals(request.uri())) {
|
||||
closeConnectionWithStatus(context, request, HttpResponseStatus.NO_CONTENT);
|
||||
} else {
|
||||
closeConnectionWithStatus(context, request, HttpResponseStatus.NOT_FOUND);
|
||||
}
|
||||
} else {
|
||||
closeConnectionWithStatus(context, request, HttpResponseStatus.METHOD_NOT_ALLOWED);
|
||||
}
|
||||
} else {
|
||||
closeConnectionWithStatus(context, request, HttpResponseStatus.BAD_REQUEST);
|
||||
}
|
||||
} finally {
|
||||
if (shouldReleaseRequest) {
|
||||
request.release();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Anything except HTTP requests should have been filtered out of the pipeline by now; treat this as an error
|
||||
ReferenceCountUtil.release(message);
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
private static void closeConnectionWithStatus(final ChannelHandlerContext context,
|
||||
final FullHttpRequest request,
|
||||
final HttpResponseStatus status) {
|
||||
|
||||
context.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), status))
|
||||
.addListener(ChannelFutureListener.CLOSE);
|
||||
}
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
|
||||
/**
|
||||
* Converts {@link OutboundCloseErrorMessage}s written to the pipeline into WebSocket close frames
|
||||
*/
|
||||
class WebSocketOutboundErrorHandler extends ChannelDuplexHandler {
|
||||
private static String SERVER_CLOSE_COUNTER_NAME = MetricsUtil.name(WebSocketOutboundErrorHandler.class, "serverClose");
|
||||
|
||||
private boolean websocketHandshakeComplete = false;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(WebSocketOutboundErrorHandler.class);
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) throws Exception {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
|
||||
setWebsocketHandshakeComplete();
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
protected void setWebsocketHandshakeComplete() {
|
||||
this.websocketHandshakeComplete = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof OutboundCloseErrorMessage err) {
|
||||
if (websocketHandshakeComplete) {
|
||||
final int status = switch (err.code()) {
|
||||
case SERVER_CLOSED -> WebSocketCloseStatus.SERVICE_RESTART.code();
|
||||
case NOISE_ERROR -> ApplicationWebSocketCloseReason.NOISE_ENCRYPTION_ERROR.getStatusCode();
|
||||
case NOISE_HANDSHAKE_ERROR -> ApplicationWebSocketCloseReason.NOISE_HANDSHAKE_ERROR.getStatusCode();
|
||||
case INTERNAL_SERVER_ERROR -> WebSocketCloseStatus.INTERNAL_SERVER_ERROR.code();
|
||||
};
|
||||
ctx.write(new CloseWebSocketFrame(new WebSocketCloseStatus(status, err.message())), promise)
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
} else {
|
||||
log.debug("Error {} occurred before websocket handshake complete", err);
|
||||
// We haven't completed a websocket handshake, so we can't really communicate errors in a semantically-meaningful
|
||||
// way; just close the connection instead.
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.MessageDigest;
|
||||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
|
||||
/**
|
||||
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
|
||||
* Noise handshake handler for the requested path.
|
||||
*/
|
||||
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final byte[] recognizedProxySecret;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
|
||||
|
||||
@VisibleForTesting
|
||||
static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy";
|
||||
|
||||
@VisibleForTesting
|
||||
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
|
||||
|
||||
private InetAddress remoteAddress = null;
|
||||
private HandshakePattern handshakePattern = null;
|
||||
|
||||
WebsocketHandshakeCompleteHandler(final String recognizedProxySecret) {
|
||||
|
||||
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
|
||||
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
|
||||
// MessageDigest.equals() later.
|
||||
this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
final Optional<InetAddress> maybePreferredRemoteAddress =
|
||||
getPreferredRemoteAddress(context, handshakeCompleteEvent);
|
||||
|
||||
if (maybePreferredRemoteAddress.isEmpty()) {
|
||||
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
|
||||
"Could not determine remote address"))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
remoteAddress = maybePreferredRemoteAddress.get();
|
||||
handshakePattern = switch (handshakeCompleteEvent.requestUri()) {
|
||||
case NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH -> HandshakePattern.IK;
|
||||
case NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH -> HandshakePattern.NK;
|
||||
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
|
||||
// internal error if something slipped through.
|
||||
default -> throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
|
||||
};
|
||||
}
|
||||
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object msg) {
|
||||
try {
|
||||
if (!(msg instanceof ByteBuf frame)) {
|
||||
throw new IllegalStateException("Unexpected msg type: " + msg.getClass());
|
||||
}
|
||||
|
||||
if (handshakePattern == null || remoteAddress == null) {
|
||||
throw new IllegalStateException("Received payload before websocket handshake complete");
|
||||
}
|
||||
|
||||
final NoiseHandshakeInit handshakeMessage =
|
||||
new NoiseHandshakeInit(remoteAddress, handshakePattern, frame);
|
||||
|
||||
context.pipeline().remove(WebsocketHandshakeCompleteHandler.class);
|
||||
context.fireChannelRead(handshakeMessage);
|
||||
} catch (Exception e) {
|
||||
ReferenceCountUtil.release(msg);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
|
||||
|
||||
final byte[] recognizedProxySecretFromHeader =
|
||||
handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
|
||||
.getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
|
||||
|
||||
if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) {
|
||||
final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER);
|
||||
|
||||
return getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
|
||||
try {
|
||||
return InetAddresses.forString(mostRecentProxy);
|
||||
} catch (final IllegalArgumentException e) {
|
||||
log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Either we don't trust the forwarded-for header or it's not present
|
||||
if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
|
||||
return Optional.of(inetSocketAddress.getAddress());
|
||||
} else {
|
||||
log.warn("Channel's remote address was not an InetSocketAddress");
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header.
|
||||
*
|
||||
* @param forwardedFor the value of an X-Forwarded-For header
|
||||
* @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or
|
||||
* {@code forwardedFor} was null
|
||||
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
|
||||
* MDN</a>
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
|
||||
return Optional.ofNullable(forwardedFor)
|
||||
.map(ff -> {
|
||||
final int idx = forwardedFor.lastIndexOf(',') + 1;
|
||||
return idx < forwardedFor.length()
|
||||
? forwardedFor.substring(idx).trim()
|
||||
: null;
|
||||
})
|
||||
.filter(StringUtils::isNotBlank);
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
|
||||
/**
|
||||
* Extracts buffers from inbound BinaryWebsocketFrames before forwarding to a
|
||||
* {@link org.whispersystems.textsecuregcm.grpc.net.NoiseHandler} for decryption and wraps outbound encrypted noise
|
||||
* packet buffers in BinaryWebsocketFrames for writing through the websocket layer.
|
||||
*/
|
||||
public class WebsocketPayloadCodec extends ChannelDuplexHandler {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
|
||||
if (msg instanceof BinaryWebSocketFrame frame) {
|
||||
ctx.fireChannelRead(frame.content());
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
ctx.write(new BinaryWebSocketFrame(bb), promise);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,11 +15,7 @@ import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.time.Clock;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
@@ -43,7 +39,6 @@ import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
|
||||
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
|
||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
|
||||
@@ -267,9 +262,8 @@ record CommandDependencies(
|
||||
() -> dynamicConfigurationManager.getConfiguration().getSvrbStatusCodesToIgnoreForAccountDeletion());
|
||||
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
|
||||
storageServiceExecutor, retryExecutor, configuration.getSecureStorageServiceConfiguration());
|
||||
GrpcClientConnectionManager grpcClientConnectionManager = new GrpcClientConnectionManager();
|
||||
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient,
|
||||
grpcClientConnectionManager, disconnectionRequestListenerExecutor, retryExecutor);
|
||||
disconnectionRequestListenerExecutor, retryExecutor);
|
||||
MessagesCache messagesCache = new MessagesCache(messagesCluster,
|
||||
messageDeliveryScheduler, messageDeletionExecutor, retryExecutor, Clock.systemUTC(), experimentEnrollmentManager);
|
||||
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, retryExecutor, asyncCdnS3Client,
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
option java_package = "org.whispersystems.textsecuregcm.grpc.net.noisedirect";
|
||||
option java_outer_classname = "NoiseDirectProtos";
|
||||
|
||||
message CloseReason {
|
||||
enum Code {
|
||||
UNSPECIFIED = 0;
|
||||
// Indicates non-error termination
|
||||
// Examples:
|
||||
// - The client is finished with the connection
|
||||
OK = 1;
|
||||
|
||||
// There was an issue with the handshake. If sent after a handshake response,
|
||||
// the response includes more information about the nature of the error
|
||||
// Examples:
|
||||
// - The client did not provide a handshake message
|
||||
// - The client had incorrect authentication credentials. The handshake
|
||||
// payload includes additional details
|
||||
HANDSHAKE_ERROR = 2;
|
||||
|
||||
// There was an encryption/decryption issue after the handshake
|
||||
// Examples:
|
||||
// - The client incorrectly encrypted a noise message and it had a bad
|
||||
// AEAD tag
|
||||
ENCRYPTION_ERROR = 3;
|
||||
|
||||
// The server is temporarily unavailable, going away, or requires a
|
||||
// connection reset
|
||||
// Examples:
|
||||
// - The server is shutting down
|
||||
// - The client’s authentication credentials have been rotated
|
||||
UNAVAILABLE = 4;
|
||||
|
||||
// There was an an internal error
|
||||
// Examples:
|
||||
// - The server experienced a temporary database outage that prevented it
|
||||
// from checking the client's credentials
|
||||
INTERNAL_ERROR = 5;
|
||||
}
|
||||
|
||||
Code code = 1;
|
||||
|
||||
// If present, includes details about the error. Implementations should never
|
||||
// parse or otherwise implement conditional logic based on the contents of the
|
||||
// error message string, it is for logging and debugging purposes only.
|
||||
string message = 2;
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
option java_package = "org.whispersystems.textsecuregcm.grpc.net";
|
||||
option java_outer_classname = "NoiseTunnelProtos";
|
||||
|
||||
message HandshakeInit {
|
||||
string user_agent = 1;
|
||||
|
||||
// An Accept-Language as described in
|
||||
// https://httpwg.org/specs/rfc9110.html#field.accept-language
|
||||
string accept_language = 2;
|
||||
|
||||
// A UUID serialized as 16 bytes (big end first). Must be unset (empty) for an
|
||||
// unauthenticated handshake
|
||||
bytes aci = 3;
|
||||
|
||||
// The deviceId, 0 < deviceId < 128. Must be unset for an unauthenticated
|
||||
// handshake
|
||||
uint32 device_id = 4;
|
||||
|
||||
// The first bytes of the application request byte stream, may contain less
|
||||
// than a full request
|
||||
bytes fast_open_request = 5;
|
||||
}
|
||||
|
||||
message HandshakeResponse {
|
||||
enum Code {
|
||||
UNSPECIFIED = 0;
|
||||
|
||||
// The noise session may be used to send application layer requests
|
||||
OK = 1;
|
||||
|
||||
// The provided client static key did not match the registered public key
|
||||
// for the provided aci/deviceId.
|
||||
WRONG_PUBLIC_KEY = 2;
|
||||
|
||||
// The client version is to old, it should be upgraded before retrying
|
||||
DEPRECATED = 3;
|
||||
}
|
||||
|
||||
// The handshake outcome
|
||||
Code code = 1;
|
||||
|
||||
// Additional information about an error status, for debugging only
|
||||
string error_details = 2;
|
||||
|
||||
// An optional response to a fast_open_request provided in the HandshakeInit.
|
||||
// Note that a response may not be present even if a fast_open_request was
|
||||
// present. If so, the response will be returned in a later message.
|
||||
bytes fast_open_response = 3;
|
||||
}
|
||||
@@ -20,8 +20,6 @@ import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
@@ -30,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
||||
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
|
||||
class DisconnectionRequestManagerTest {
|
||||
|
||||
private GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
private DisconnectionRequestManager disconnectionRequestManager;
|
||||
|
||||
@RegisterExtension
|
||||
@@ -38,10 +35,8 @@ class DisconnectionRequestManagerTest {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
|
||||
|
||||
disconnectionRequestManager = new DisconnectionRequestManager(REDIS_EXTENSION.getRedisClient(),
|
||||
grpcClientConnectionManager,
|
||||
Runnable::run,
|
||||
mock(ScheduledExecutorService.class));
|
||||
|
||||
@@ -103,16 +98,8 @@ class DisconnectionRequestManagerTest {
|
||||
|
||||
verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest();
|
||||
verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest();
|
||||
verify(grpcClientConnectionManager, timeout(1_000))
|
||||
.closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId));
|
||||
|
||||
verify(grpcClientConnectionManager, timeout(1_000))
|
||||
.closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId));
|
||||
|
||||
disconnectionRequestManager.requestDisconnection(otherAccountIdentifier, List.of(otherDeviceId));
|
||||
|
||||
verify(grpcClientConnectionManager, timeout(1_000))
|
||||
.closeConnection(new AuthenticatedDevice(otherAccountIdentifier, otherDeviceId));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -141,11 +128,5 @@ class DisconnectionRequestManagerTest {
|
||||
|
||||
verify(primaryDeviceListener, timeout(1_000)).handleDisconnectionRequest();
|
||||
verify(linkedDeviceListener, timeout(1_000)).handleDisconnectionRequest();
|
||||
|
||||
verify(grpcClientConnectionManager, timeout(1_000))
|
||||
.closeConnection(new AuthenticatedDevice(accountIdentifier, primaryDeviceId));
|
||||
|
||||
verify(grpcClientConnectionManager, timeout(1_000))
|
||||
.closeConnection(new AuthenticatedDevice(accountIdentifier, linkedDeviceId));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import java.io.IOException;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
abstract class AbstractAuthenticationInterceptorTest {
|
||||
|
||||
private static DefaultEventLoopGroup eventLoopGroup;
|
||||
|
||||
private GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
private Server server;
|
||||
private ManagedChannel managedChannel;
|
||||
|
||||
@BeforeAll
|
||||
static void setUpBeforeAll() {
|
||||
eventLoopGroup = new DefaultEventLoopGroup();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws IOException {
|
||||
final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server");
|
||||
|
||||
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
|
||||
|
||||
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
|
||||
// sure that we're using local channels and addresses
|
||||
server = NettyServerBuilder.forAddress(serverAddress)
|
||||
.channelType(LocalServerChannel.class)
|
||||
.bossEventLoopGroup(eventLoopGroup)
|
||||
.workerEventLoopGroup(eventLoopGroup)
|
||||
.intercept(getInterceptor())
|
||||
.addService(new RequestAttributesServiceImpl())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
|
||||
.channelType(LocalChannel.class)
|
||||
.eventLoopGroup(eventLoopGroup)
|
||||
.usePlaintext()
|
||||
.build();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
managedChannel.shutdown();
|
||||
server.shutdown();
|
||||
}
|
||||
|
||||
protected abstract AbstractAuthenticationInterceptor getInterceptor();
|
||||
|
||||
protected GrpcClientConnectionManager getClientConnectionManager() {
|
||||
return grpcClientConnectionManager;
|
||||
}
|
||||
|
||||
protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() {
|
||||
return RequestAttributesGrpc.newBlockingStub(managedChannel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.CallCredentials;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Status;
|
||||
import java.util.concurrent.Executor;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
|
||||
public class BasicAuthCallCredentials extends CallCredentials {
|
||||
|
||||
private final String username;
|
||||
private final String password;
|
||||
|
||||
public BasicAuthCallCredentials(String username, String password) {
|
||||
this.username = username;
|
||||
this.password = password;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor,
|
||||
final MetadataApplier applier) {
|
||||
try {
|
||||
Metadata headers = new Metadata();
|
||||
headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER),
|
||||
HeaderUtils.basicAuthHeader(username, password));
|
||||
applier.apply(headers);
|
||||
} catch (Exception e) {
|
||||
applier.fail(Status.UNAUTHENTICATED.withCause(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,44 +1,64 @@
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.signal.chat.rpc.EchoRequest;
|
||||
import org.signal.chat.rpc.EchoServiceGrpc;
|
||||
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
|
||||
class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
|
||||
class ProhibitAuthenticationInterceptorTest {
|
||||
private Server server;
|
||||
private ManagedChannel channel;
|
||||
|
||||
@Override
|
||||
protected AbstractAuthenticationInterceptor getInterceptor() {
|
||||
return new ProhibitAuthenticationInterceptor(getClientConnectionManager());
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest")
|
||||
.directExecutor()
|
||||
.intercept(new ProhibitAuthenticationInterceptor())
|
||||
.addService(new EchoServiceImpl())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest")
|
||||
.directExecutor()
|
||||
.build();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
channel.shutdownNow();
|
||||
server.shutdownNow();
|
||||
channel.awaitTermination(5, TimeUnit.SECONDS);
|
||||
server.awaitTermination(5, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@Test
|
||||
void interceptCall() throws ChannelNotFoundException {
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
|
||||
void hasAuth() {
|
||||
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc
|
||||
.newBlockingStub(channel)
|
||||
.withCallCredentials(new BasicAuthCallCredentials("test", "password"));
|
||||
|
||||
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
|
||||
final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
|
||||
() -> client.echo(EchoRequest.getDefaultInstance()));
|
||||
assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
|
||||
}
|
||||
|
||||
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
|
||||
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||
assertEquals(0, response.getDeviceId());
|
||||
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
|
||||
|
||||
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
|
||||
|
||||
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
|
||||
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
|
||||
@Test
|
||||
void noAuth() {
|
||||
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
|
||||
assertDoesNotThrow(() -> client.echo(EchoRequest.getDefaultInstance()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,44 +1,102 @@
|
||||
package org.whispersystems.textsecuregcm.auth.grpc;
|
||||
|
||||
import static org.junit.Assert.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.dropwizard.auth.basic.BasicCredentials;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
import io.grpc.inprocess.InProcessChannelBuilder;
|
||||
import io.grpc.inprocess.InProcessServerBuilder;
|
||||
import java.time.Instant;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
|
||||
class RequireAuthenticationInterceptorTest {
|
||||
private Server server;
|
||||
private ManagedChannel channel;
|
||||
private AccountAuthenticator authenticator;
|
||||
|
||||
@Override
|
||||
protected AbstractAuthenticationInterceptor getInterceptor() {
|
||||
return new RequireAuthenticationInterceptor(getClientConnectionManager());
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
authenticator = mock(AccountAuthenticator.class);
|
||||
server = InProcessServerBuilder.forName("RequestAttributesInterceptorTest")
|
||||
.directExecutor()
|
||||
.intercept(new RequireAuthenticationInterceptor(authenticator))
|
||||
.addService(new RequestAttributesServiceImpl())
|
||||
.build()
|
||||
.start();
|
||||
|
||||
channel = InProcessChannelBuilder.forName("RequestAttributesInterceptorTest")
|
||||
.directExecutor()
|
||||
.build();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
channel.shutdownNow();
|
||||
server.shutdownNow();
|
||||
channel.awaitTermination(5, TimeUnit.SECONDS);
|
||||
server.awaitTermination(5, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@Test
|
||||
void interceptCall() throws ChannelNotFoundException {
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
|
||||
void hasAuth() {
|
||||
final UUID aci = UUID.randomUUID();
|
||||
final byte deviceId = 2;
|
||||
when(authenticator.authenticate(eq(new BasicCredentials("test", "password"))))
|
||||
.thenReturn(Optional.of(
|
||||
new org.whispersystems.textsecuregcm.auth.AuthenticatedDevice(aci, deviceId, Instant.now())));
|
||||
|
||||
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
|
||||
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
|
||||
.newBlockingStub(channel)
|
||||
.withCallCredentials(new BasicAuthCallCredentials("test", "password"));
|
||||
|
||||
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
|
||||
final GetAuthenticatedDeviceResponse authenticatedDevice = client.getAuthenticatedDevice(
|
||||
GetAuthenticatedDeviceRequest.getDefaultInstance());
|
||||
assertEquals(authenticatedDevice.getDeviceId(), deviceId);
|
||||
assertEquals(UUIDUtil.fromByteString(authenticatedDevice.getAccountIdentifier()), aci);
|
||||
}
|
||||
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
|
||||
@Test
|
||||
void badCredentials() {
|
||||
when(authenticator.authenticate(any())).thenReturn(Optional.empty());
|
||||
|
||||
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
|
||||
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
|
||||
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
|
||||
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
|
||||
.newBlockingStub(channel)
|
||||
.withCallCredentials(new BasicAuthCallCredentials("test", "password"));
|
||||
|
||||
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
|
||||
final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
|
||||
() -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance()));
|
||||
Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
|
||||
}
|
||||
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
|
||||
@Test
|
||||
void missingCredentials() {
|
||||
when(authenticator.authenticate(any())).thenReturn(Optional.empty());
|
||||
|
||||
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc.newBlockingStub(channel);
|
||||
|
||||
final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
|
||||
() -> client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance()));
|
||||
Assert.assertEquals(e.getStatus().getCode(), Status.Code.UNAUTHENTICATED);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.Status;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
|
||||
class ChannelShutdownInterceptorTest {
|
||||
|
||||
private GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
private ChannelShutdownInterceptor channelShutdownInterceptor;
|
||||
|
||||
private ServerCallHandler<String, String> nextCallHandler;
|
||||
|
||||
private static final Metadata HEADERS = new Metadata();
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
|
||||
channelShutdownInterceptor = new ChannelShutdownInterceptor(grpcClientConnectionManager);
|
||||
|
||||
//noinspection unchecked
|
||||
nextCallHandler = mock(ServerCallHandler.class);
|
||||
|
||||
//noinspection unchecked
|
||||
when(nextCallHandler.startCall(any(), any())).thenReturn(mock(ServerCall.Listener.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void interceptCallComplete() {
|
||||
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
|
||||
|
||||
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
|
||||
|
||||
final ServerCall.Listener<String> serverCallListener =
|
||||
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
|
||||
|
||||
serverCallListener.onComplete();
|
||||
|
||||
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
|
||||
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
|
||||
verify(serverCall, never()).close(any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void interceptCallCancelled() {
|
||||
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
|
||||
|
||||
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(true);
|
||||
|
||||
final ServerCall.Listener<String> serverCallListener =
|
||||
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
|
||||
|
||||
serverCallListener.onCancel();
|
||||
|
||||
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
|
||||
verify(grpcClientConnectionManager).handleServerCallComplete(serverCall);
|
||||
verify(serverCall, never()).close(any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void interceptCallChannelClosing() {
|
||||
@SuppressWarnings("unchecked") final ServerCall<String, String> serverCall = mock(ServerCall.class);
|
||||
|
||||
when(grpcClientConnectionManager.handleServerCallStart(serverCall)).thenReturn(false);
|
||||
|
||||
channelShutdownInterceptor.interceptCall(serverCall, HEADERS, nextCallHandler);
|
||||
|
||||
verify(grpcClientConnectionManager).handleServerCallStart(serverCall);
|
||||
verify(grpcClientConnectionManager, never()).handleServerCallComplete(serverCall);
|
||||
verify(serverCall).close(eq(Status.UNAVAILABLE), any());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Metadata;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.ServerCall;
|
||||
import io.grpc.ServerCallHandler;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import io.grpc.stub.MetadataUtils;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
import org.signal.chat.rpc.GetRequestAttributesResponse;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
|
||||
public class RequestAttributesInterceptorTest {
|
||||
|
||||
private static String USER_AGENT = "Signal-Android/4.53.7 (Android 8.1; libsignal)";
|
||||
private Server server;
|
||||
private AtomicBoolean removeUserAgent;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
removeUserAgent = new AtomicBoolean(false);
|
||||
|
||||
server = NettyServerBuilder.forPort(0)
|
||||
.directExecutor()
|
||||
.intercept(new RequestAttributesInterceptor())
|
||||
// the grpc client always inserts a user-agent if we don't set one, so to test missing UAs we remove the header
|
||||
// on the server-side
|
||||
.intercept(new ServerInterceptor() {
|
||||
@Override
|
||||
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
|
||||
final Metadata headers, final ServerCallHandler<ReqT, RespT> next) {
|
||||
if (removeUserAgent.get()) {
|
||||
headers.removeAll(Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER));
|
||||
}
|
||||
return next.startCall(call, headers);
|
||||
}
|
||||
})
|
||||
.addService(new RequestAttributesServiceImpl())
|
||||
.build()
|
||||
.start();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
server.shutdownNow();
|
||||
server.awaitTermination(1, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
private static List<Arguments> handleInvalidAcceptLanguage() {
|
||||
return List.of(
|
||||
Arguments.argumentSet("Null Accept-Language header", Optional.empty()),
|
||||
Arguments.argumentSet("Empty Accept-Language header", Optional.of("")),
|
||||
Arguments.argumentSet("Invalid Accept-Language header", Optional.of("This is not a valid language preference list")));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleInvalidAcceptLanguage(Optional<String> acceptLanguageHeader) throws Exception {
|
||||
final Metadata metadata = new Metadata();
|
||||
acceptLanguageHeader.ifPresent(h -> metadata
|
||||
.put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), h));
|
||||
final GetRequestAttributesResponse response = getRequestAttributes(metadata);
|
||||
assertEquals(response.getAcceptableLanguagesCount(), 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleMissingUserAgent() throws InterruptedException {
|
||||
removeUserAgent.set(true);
|
||||
final GetRequestAttributesResponse response = getRequestAttributes(new Metadata());
|
||||
assertEquals("", response.getUserAgent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void allAttributes() throws InterruptedException {
|
||||
final Metadata metadata = new Metadata();
|
||||
metadata.put(Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER), "ja,en;q=0.4");
|
||||
metadata.put(Metadata.Key.of("x-forwarded-for", Metadata.ASCII_STRING_MARSHALLER), "127.0.0.3");
|
||||
final GetRequestAttributesResponse response = getRequestAttributes(metadata);
|
||||
|
||||
assertTrue(response.getUserAgent().contains(USER_AGENT));
|
||||
assertEquals("127.0.0.3", response.getRemoteAddress());
|
||||
assertEquals(2, response.getAcceptableLanguagesCount());
|
||||
assertEquals("ja", response.getAcceptableLanguages(0));
|
||||
assertEquals("en;q=0.4", response.getAcceptableLanguages(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void useSocketAddrIfHeaderMissing() throws InterruptedException {
|
||||
final GetRequestAttributesResponse response = getRequestAttributes(new Metadata());
|
||||
assertEquals("127.0.0.1", response.getRemoteAddress());
|
||||
}
|
||||
|
||||
private GetRequestAttributesResponse getRequestAttributes(Metadata metadata)
|
||||
throws InterruptedException {
|
||||
final ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", server.getPort())
|
||||
.directExecutor()
|
||||
.usePlaintext()
|
||||
.userAgent(USER_AGENT)
|
||||
.build();
|
||||
try {
|
||||
final RequestAttributesGrpc.RequestAttributesBlockingStub client = RequestAttributesGrpc
|
||||
.newBlockingStub(channel)
|
||||
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
|
||||
return client.getRequestAttributes(GetRequestAttributesRequest.getDefaultInstance());
|
||||
} finally {
|
||||
channel.shutdownNow();
|
||||
channel.awaitTermination(1, TimeUnit.SECONDS);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import io.netty.util.ResourceLeakDetector;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
|
||||
public abstract class AbstractLeakDetectionTest {
|
||||
|
||||
private static ResourceLeakDetector.Level originalResourceLeakDetectorLevel;
|
||||
|
||||
@BeforeAll
|
||||
static void setLeakDetectionLevel() {
|
||||
originalResourceLeakDetectorLevel = ResourceLeakDetector.getLevel();
|
||||
ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void restoreLeakDetectionLevel() {
|
||||
ResourceLeakDetector.setLevel(originalResourceLeakDetectorLevel);
|
||||
}
|
||||
}
|
||||
@@ -1,325 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.InetAddress;
|
||||
import java.net.UnknownHostException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.crypto.AEADBadTagException;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
|
||||
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
protected ECKeyPair serverKeyPair;
|
||||
protected ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private NoiseHandshakeCompleteHandler noiseHandshakeCompleteHandler;
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
static final String USER_AGENT = "Test/User-Agent";
|
||||
static final String ACCEPT_LANGUAGE = "test-lang";
|
||||
static final InetAddress REMOTE_ADDRESS;
|
||||
static {
|
||||
try {
|
||||
REMOTE_ADDRESS = InetAddress.getByAddress(new byte[]{0,1,2,3});
|
||||
} catch (UnknownHostException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static class PongHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext ctx, final Object msg) {
|
||||
try {
|
||||
if (msg instanceof ByteBuf bb) {
|
||||
if (new String(ByteBufUtil.getBytes(bb)).equals("ping")) {
|
||||
ctx.writeAndFlush(Unpooled.wrappedBuffer("pong".getBytes()))
|
||||
.addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unexpected message: " + new String(ByteBufUtil.getBytes(bb)));
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unexpected message type: " + msg);
|
||||
}
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class NoiseHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
@Nullable
|
||||
private NoiseIdentityDeterminedEvent handshakeCompleteEvent = null;
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof NoiseIdentityDeterminedEvent noiseIdentityDeterminedEvent) {
|
||||
handshakeCompleteEvent = noiseIdentityDeterminedEvent;
|
||||
context.pipeline().addAfter(context.name(), null, new PongHandler());
|
||||
context.pipeline().remove(NoiseHandshakeCompleteHandler.class);
|
||||
} else {
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public NoiseIdentityDeterminedEvent getHandshakeCompleteEvent() {
|
||||
return handshakeCompleteEvent;
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
serverKeyPair = ECKeyPair.generate();
|
||||
noiseHandshakeCompleteHandler = new NoiseHandshakeCompleteHandler();
|
||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||
embeddedChannel = new EmbeddedChannel(
|
||||
new NoiseHandshakeHandler(clientPublicKeysManager, serverKeyPair),
|
||||
noiseHandshakeCompleteHandler);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
embeddedChannel.close();
|
||||
}
|
||||
|
||||
protected EmbeddedChannel getEmbeddedChannel() {
|
||||
return embeddedChannel;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
protected NoiseIdentityDeterminedEvent getNoiseHandshakeCompleteEvent() {
|
||||
return noiseHandshakeCompleteHandler.getHandshakeCompleteEvent();
|
||||
}
|
||||
|
||||
protected abstract CipherStatePair doHandshake() throws Throwable;
|
||||
|
||||
/**
|
||||
* Read a message from the embedded channel and deserialize it with the provided client cipher state. If there are no
|
||||
* waiting messages in the channel, return null.
|
||||
*/
|
||||
byte[] readNextPlaintext(final CipherStatePair clientCipherPair) throws ShortBufferException, BadPaddingException {
|
||||
final ByteBuf responseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||
if (responseFrame == null) {
|
||||
return null;
|
||||
}
|
||||
final byte[] plaintext = new byte[responseFrame.readableBytes() - 16];
|
||||
final int read = clientCipherPair.getReceiver().decryptWithAd(null,
|
||||
ByteBufUtil.getBytes(responseFrame), 0,
|
||||
plaintext, 0,
|
||||
responseFrame.readableBytes());
|
||||
assertEquals(read, plaintext.length);
|
||||
return plaintext;
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void handleInvalidInitialMessage() throws InterruptedException {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
final ByteBuf content = Unpooled.wrappedBuffer(contentBytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(new NoiseHandshakeInit(REMOTE_ADDRESS, HandshakePattern.IK, content)).await();
|
||||
|
||||
assertFalse(writeFuture.isSuccess());
|
||||
assertInstanceOf(NoiseHandshakeException.class, writeFuture.cause());
|
||||
assertEquals(0, content.refCnt());
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleMessagesAfterInitialHandshakeFailure() throws InterruptedException {
|
||||
final ByteBuf[] frames = new ByteBuf[7];
|
||||
|
||||
for (int i = 0; i < frames.length; i++) {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
frames[i] = Unpooled.wrappedBuffer(contentBytes);
|
||||
|
||||
embeddedChannel.writeOneInbound(frames[i]).await();
|
||||
}
|
||||
|
||||
for (final ByteBuf frame : frames) {
|
||||
assertEquals(0, frame.refCnt());
|
||||
}
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleNonByteBufBinaryFrame() throws Throwable {
|
||||
final byte[] contentBytes = new byte[17];
|
||||
ThreadLocalRandom.current().nextBytes(contentBytes);
|
||||
|
||||
final BinaryWebSocketFrame message = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(contentBytes));
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message).await();
|
||||
|
||||
assertFalse(writeFuture.isSuccess());
|
||||
assertInstanceOf(IllegalArgumentException.class, writeFuture.cause());
|
||||
assertEquals(0, message.refCnt());
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelRead() throws Throwable {
|
||||
final CipherStatePair clientCipherStatePair = doHandshake();
|
||||
final byte[] plaintext = "ping".getBytes(StandardCharsets.UTF_8);
|
||||
final byte[] ciphertext = new byte[plaintext.length + clientCipherStatePair.getSender().getMACLength()];
|
||||
clientCipherStatePair.getSender().encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.length);
|
||||
|
||||
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(ciphertext);
|
||||
assertTrue(embeddedChannel.writeOneInbound(ciphertextFrame).await().isSuccess());
|
||||
assertEquals(0, ciphertextFrame.refCnt());
|
||||
|
||||
final byte[] response = readNextPlaintext(clientCipherStatePair);
|
||||
assertArrayEquals("pong".getBytes(StandardCharsets.UTF_8), response);
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelReadBadCiphertext() throws Throwable {
|
||||
doHandshake();
|
||||
final byte[] bogusCiphertext = new byte[32];
|
||||
io.netty.util.internal.ThreadLocalRandom.current().nextBytes(bogusCiphertext);
|
||||
|
||||
final ByteBuf ciphertextFrame = Unpooled.wrappedBuffer(bogusCiphertext);
|
||||
final ChannelFuture readCiphertextFuture = embeddedChannel.writeOneInbound(ciphertextFrame).await();
|
||||
|
||||
assertEquals(0, ciphertextFrame.refCnt());
|
||||
assertFalse(readCiphertextFuture.isSuccess());
|
||||
assertInstanceOf(AEADBadTagException.class, readCiphertextFuture.cause());
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void channelReadUnexpectedMessageType() throws Throwable {
|
||||
doHandshake();
|
||||
final ChannelFuture readFuture = embeddedChannel.writeOneInbound(new Object()).await();
|
||||
|
||||
assertFalse(readFuture.isSuccess());
|
||||
assertInstanceOf(IllegalArgumentException.class, readFuture.cause());
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void write() throws Throwable {
|
||||
final CipherStatePair clientCipherStatePair = doHandshake();
|
||||
final byte[] plaintext = "A plaintext message".getBytes(StandardCharsets.UTF_8);
|
||||
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(plaintext);
|
||||
|
||||
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
|
||||
assertTrue(writePlaintextFuture.await().isSuccess());
|
||||
assertEquals(0, plaintextBuffer.refCnt());
|
||||
|
||||
final ByteBuf ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||
assertNotNull(ciphertextFrame);
|
||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||
|
||||
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
|
||||
ciphertextFrame.release();
|
||||
|
||||
final byte[] decryptedPlaintext = new byte[ciphertext.length - clientCipherStatePair.getReceiver().getMACLength()];
|
||||
clientCipherStatePair.getReceiver().decryptWithAd(null, ciphertext, 0, decryptedPlaintext, 0, ciphertext.length);
|
||||
|
||||
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||
}
|
||||
|
||||
@Test
|
||||
void writeUnexpectedMessageType() throws Throwable {
|
||||
doHandshake();
|
||||
final Object unexpectedMessaged = new Object();
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.pipeline().writeAndFlush(unexpectedMessaged);
|
||||
assertTrue(writeFuture.await().isSuccess());
|
||||
|
||||
assertEquals(unexpectedMessaged, embeddedChannel.outboundMessages().poll());
|
||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5})
|
||||
void writeHugeOutboundMessage(final int plaintextLength) throws Throwable {
|
||||
final CipherStatePair clientCipherStatePair = doHandshake();
|
||||
final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength);
|
||||
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length));
|
||||
|
||||
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
|
||||
assertTrue(writePlaintextFuture.isSuccess());
|
||||
|
||||
final byte[] decryptedPlaintext = new byte[plaintextLength];
|
||||
int plaintextOffset = 0;
|
||||
ByteBuf ciphertextFrame;
|
||||
while ((ciphertextFrame = (ByteBuf) embeddedChannel.outboundMessages().poll()) != null) {
|
||||
assertTrue(ciphertextFrame.readableBytes() <= Noise.MAX_PACKET_LEN);
|
||||
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame);
|
||||
ciphertextFrame.release();
|
||||
plaintextOffset += clientCipherStatePair.getReceiver()
|
||||
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
|
||||
}
|
||||
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||
assertEquals(0, plaintextBuffer.refCnt());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writeHugeInboundMessage() throws Throwable {
|
||||
doHandshake();
|
||||
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
|
||||
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(big));
|
||||
assertThrows(NoiseException.class, embeddedChannel::checkException);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void channelAttributes() throws Throwable {
|
||||
doHandshake();
|
||||
final NoiseIdentityDeterminedEvent event = getNoiseHandshakeCompleteEvent();
|
||||
assertEquals(REMOTE_ADDRESS, event.remoteAddress());
|
||||
assertEquals(USER_AGENT, event.userAgent());
|
||||
assertEquals(ACCEPT_LANGUAGE, event.acceptLanguage());
|
||||
}
|
||||
|
||||
protected NoiseTunnelProtos.HandshakeInit.Builder baseHandshakeInit() {
|
||||
return NoiseTunnelProtos.HandshakeInit.newBuilder()
|
||||
.setUserAgent(USER_AGENT)
|
||||
.setAcceptLanguage(ACCEPT_LANGUAGE);
|
||||
}
|
||||
}
|
||||
@@ -1,451 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyByte;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ServerBuilder;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.handler.codec.haproxy.HAProxyCommand;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.function.Supplier;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.signal.chat.rpc.EchoRequest;
|
||||
import org.signal.chat.rpc.EchoResponse;
|
||||
import org.signal.chat.rpc.EchoServiceGrpc;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
|
||||
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelShutdownInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
public abstract class AbstractNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private static NioEventLoopGroup nioEventLoopGroup;
|
||||
private static DefaultEventLoopGroup defaultEventLoopGroup;
|
||||
private static ExecutorService delegatedTaskExecutor;
|
||||
private static ExecutorService serverCallExecutor;
|
||||
|
||||
private GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
private ClientPublicKeysManager clientPublicKeysManager;
|
||||
|
||||
private ECKeyPair serverKeyPair;
|
||||
private ECKeyPair clientKeyPair;
|
||||
|
||||
private ManagedLocalGrpcServer authenticatedGrpcServer;
|
||||
private ManagedLocalGrpcServer anonymousGrpcServer;
|
||||
|
||||
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
|
||||
private static final byte DEVICE_ID = Device.PRIMARY_ID;
|
||||
|
||||
public static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
|
||||
|
||||
@BeforeAll
|
||||
static void setUpBeforeAll() {
|
||||
nioEventLoopGroup = new NioEventLoopGroup();
|
||||
defaultEventLoopGroup = new DefaultEventLoopGroup();
|
||||
delegatedTaskExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||
serverCallExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
|
||||
clientKeyPair = ECKeyPair.generate();
|
||||
serverKeyPair = ECKeyPair.generate();
|
||||
|
||||
grpcClientConnectionManager = new GrpcClientConnectionManager();
|
||||
|
||||
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
|
||||
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("test-grpc-service-authenticated");
|
||||
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("test-grpc-service-anonymous");
|
||||
|
||||
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
|
||||
@Override
|
||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||
serverBuilder
|
||||
.executor(serverCallExecutor)
|
||||
.addService(new RequestAttributesServiceImpl())
|
||||
.addService(new EchoServiceImpl())
|
||||
.intercept(new ChannelShutdownInterceptor(grpcClientConnectionManager))
|
||||
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
||||
.intercept(new RequireAuthenticationInterceptor(grpcClientConnectionManager));
|
||||
}
|
||||
};
|
||||
|
||||
authenticatedGrpcServer.start();
|
||||
|
||||
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
|
||||
@Override
|
||||
protected void configureServer(final ServerBuilder<?> serverBuilder) {
|
||||
serverBuilder
|
||||
.executor(serverCallExecutor)
|
||||
.addService(new RequestAttributesServiceImpl())
|
||||
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
|
||||
.intercept(new ProhibitAuthenticationInterceptor(grpcClientConnectionManager));
|
||||
}
|
||||
};
|
||||
|
||||
anonymousGrpcServer.start();
|
||||
this.start(
|
||||
nioEventLoopGroup,
|
||||
delegatedTaskExecutor,
|
||||
grpcClientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
serverKeyPair,
|
||||
authenticatedGrpcServerAddress, anonymousGrpcServerAddress,
|
||||
RECOGNIZED_PROXY_SECRET);
|
||||
}
|
||||
|
||||
|
||||
protected abstract void start(
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final Executor delegatedTaskExecutor,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair serverKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final String recognizedProxySecret) throws Exception;
|
||||
protected abstract void stop() throws Exception;
|
||||
protected abstract NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey);
|
||||
|
||||
public void assertClosedWith(final NoiseTunnelClient client, final CloseFrameEvent.CloseReason reason)
|
||||
throws ExecutionException, InterruptedException, TimeoutException {
|
||||
final CloseFrameEvent result = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
|
||||
assertEquals(reason, result.closeReason());
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
authenticatedGrpcServer.stop();
|
||||
anonymousGrpcServer.stop();
|
||||
this.stop();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void tearDownAfterAll() throws InterruptedException {
|
||||
nioEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
||||
defaultEventLoopGroup.shutdownGracefully(100, 100, TimeUnit.MILLISECONDS).await();
|
||||
|
||||
delegatedTaskExecutor.shutdown();
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
delegatedTaskExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||
|
||||
serverCallExecutor.shutdown();
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
serverCallExecutor.awaitTermination(1, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = {true, false})
|
||||
void connectAuthenticated(final boolean includeProxyMessage) throws InterruptedException {
|
||||
try (final NoiseTunnelClient client = authenticated()
|
||||
.setProxyMessageSupplier(proxyMessageSupplier(includeProxyMessage))
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAuthenticatedBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
|
||||
// Try to verify the server's public key with something other than the key with which it was signed
|
||||
try (final NoiseTunnelClient client = authenticated()
|
||||
.setServerPublicKey(ECKeyPair.generate().getPublicKey())
|
||||
.build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAuthenticatedMismatchedClientPublicKey() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey())));
|
||||
|
||||
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
assertEquals(
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
|
||||
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAuthenticatedUnrecognizedDevice() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
when(clientPublicKeysManager.findPublicKey(ACCOUNT_IDENTIFIER, DEVICE_ID))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
assertEquals(
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY,
|
||||
client.getHandshakeEventFuture().get(1, TimeUnit.SECONDS).handshakeResponse().getCode());
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void clientNormalClosure() throws InterruptedException {
|
||||
final NoiseTunnelClient client = anonymous().build();
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
try {
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||
assertEquals(0, response.getDeviceId());
|
||||
client.close();
|
||||
|
||||
// When we gracefully close the tunnel client, we should send an OK close frame
|
||||
final CloseFrameEvent closeFrame = client.closeFrameFuture().join();
|
||||
assertEquals(CloseFrameEvent.CloseInitiator.CLIENT, closeFrame.closeInitiator());
|
||||
assertEquals(CloseFrameEvent.CloseReason.OK, closeFrame.closeReason());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAnonymous() throws InterruptedException {
|
||||
try (final NoiseTunnelClient client = anonymous().build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertTrue(response.getAccountIdentifier().isEmpty());
|
||||
assertEquals(0, response.getDeviceId());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAnonymousBadServerKeySignature() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
|
||||
// Try to verify the server's public key with something other than the key with which it was signed
|
||||
try (final NoiseTunnelClient client = anonymous()
|
||||
.setServerPublicKey(ECKeyPair.generate().getPublicKey())
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
protected ManagedChannel buildManagedChannel(final LocalAddress localAddress) {
|
||||
return NettyChannelBuilder.forAddress(localAddress)
|
||||
.channelType(LocalChannel.class)
|
||||
.eventLoopGroup(defaultEventLoopGroup)
|
||||
.usePlaintext()
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void closeForReauthentication() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
|
||||
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||
|
||||
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
||||
final CloseFrameEvent closeEvent = client.closeFrameFuture().get(2, TimeUnit.SECONDS);
|
||||
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeEvent.closeReason());
|
||||
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeEvent.closeInitiator());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void waitForCallCompletion() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
try (final NoiseTunnelClient client = authenticated().build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
|
||||
|
||||
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
|
||||
assertEquals(DEVICE_ID, response.getDeviceId());
|
||||
|
||||
final CountDownLatch responseCountDownLatch = new CountDownLatch(1);
|
||||
|
||||
// Start an open-ended server call and leave it in a non-complete state
|
||||
final StreamObserver<EchoRequest> echoRequestStreamObserver = EchoServiceGrpc.newStub(channel).echoStream(
|
||||
new StreamObserver<>() {
|
||||
@Override
|
||||
public void onNext(final EchoResponse echoResponse) {
|
||||
responseCountDownLatch.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(final Throwable throwable) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCompleted() {
|
||||
}
|
||||
});
|
||||
|
||||
// Requests are transmitted asynchronously; it's possible that we'll issue the "close connection" request before
|
||||
// the request even starts. Make sure we've done at least one request/response pair to ensure that the call has
|
||||
// truly started before requesting connection closure.
|
||||
echoRequestStreamObserver.onNext(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build());
|
||||
assertTrue(responseCountDownLatch.await(1, TimeUnit.SECONDS));
|
||||
|
||||
grpcClientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
|
||||
try {
|
||||
client.closeFrameFuture().get(100, TimeUnit.MILLISECONDS);
|
||||
fail("Channel should not close until active requests have finished");
|
||||
} catch (TimeoutException e) {
|
||||
}
|
||||
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, () -> EchoServiceGrpc.newBlockingStub(channel)
|
||||
.echo(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("Test")).build()));
|
||||
|
||||
// Complete the open-ended server call
|
||||
echoRequestStreamObserver.onCompleted();
|
||||
|
||||
final CloseFrameEvent closeFrameEvent = client.closeFrameFuture().get(1, TimeUnit.SECONDS);
|
||||
assertEquals(CloseFrameEvent.CloseInitiator.SERVER, closeFrameEvent.closeInitiator());
|
||||
assertEquals(CloseFrameEvent.CloseReason.SERVER_CLOSED, closeFrameEvent.closeReason());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected NoiseTunnelClient.Builder anonymous() {
|
||||
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey());
|
||||
}
|
||||
|
||||
protected NoiseTunnelClient.Builder authenticated() {
|
||||
return clientBuilder(nioEventLoopGroup, serverKeyPair.getPublicKey())
|
||||
.setAuthenticated(clientKeyPair, ACCOUNT_IDENTIFIER, DEVICE_ID);
|
||||
}
|
||||
|
||||
private static Supplier<HAProxyMessage> proxyMessageSupplier(boolean includeProxyMesage) {
|
||||
return includeProxyMesage
|
||||
? () -> new HAProxyMessage(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
||||
"10.0.0.1", "10.0.0.2", 12345, 443)
|
||||
: null;
|
||||
}
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.DefaultEventLoopGroup;
|
||||
import io.netty.channel.EventLoopGroup;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import java.net.InetAddress;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
class GrpcClientConnectionManagerTest {
|
||||
|
||||
private static EventLoopGroup eventLoopGroup;
|
||||
|
||||
private LocalChannel localChannel;
|
||||
private LocalChannel remoteChannel;
|
||||
|
||||
private LocalServerChannel localServerChannel;
|
||||
|
||||
private GrpcClientConnectionManager grpcClientConnectionManager;
|
||||
|
||||
@BeforeAll
|
||||
static void setUpBeforeAll() {
|
||||
eventLoopGroup = new DefaultEventLoopGroup();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws InterruptedException {
|
||||
eventLoopGroup = new DefaultEventLoopGroup();
|
||||
|
||||
grpcClientConnectionManager = new GrpcClientConnectionManager();
|
||||
|
||||
// We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial
|
||||
// local server to which we can open trivial local connections
|
||||
localServerChannel = (LocalServerChannel) new ServerBootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(LocalServerChannel.class)
|
||||
.childHandler(new ChannelInitializer<>() {
|
||||
@Override
|
||||
protected void initChannel(final Channel channel) {
|
||||
}
|
||||
})
|
||||
.bind(new LocalAddress("test-server"))
|
||||
.await()
|
||||
.channel();
|
||||
|
||||
final Bootstrap clientBootstrap = new Bootstrap()
|
||||
.group(eventLoopGroup)
|
||||
.channel(LocalChannel.class)
|
||||
.handler(new ChannelInitializer<>() {
|
||||
@Override
|
||||
protected void initChannel(final Channel ch) {
|
||||
}
|
||||
});
|
||||
|
||||
localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
|
||||
remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws InterruptedException {
|
||||
localChannel.close().await();
|
||||
remoteChannel.close().await();
|
||||
localServerChannel.close().await();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
static void tearDownAfterAll() throws InterruptedException {
|
||||
eventLoopGroup.shutdownGracefully().await();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
|
||||
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
|
||||
|
||||
assertEquals(maybeAuthenticatedDevice,
|
||||
grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
|
||||
}
|
||||
|
||||
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
|
||||
return List.of(
|
||||
Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)),
|
||||
Optional.empty()
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getRequestAttributes() {
|
||||
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
|
||||
|
||||
assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
|
||||
|
||||
final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
|
||||
remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
|
||||
|
||||
assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
|
||||
}
|
||||
|
||||
@Test
|
||||
void closeConnection() throws InterruptedException, ChannelNotFoundException {
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
|
||||
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
|
||||
|
||||
assertTrue(remoteChannel.isOpen());
|
||||
|
||||
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
assertEquals(List.of(remoteChannel),
|
||||
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
|
||||
remoteChannel.close().await();
|
||||
|
||||
assertThrows(ChannelNotFoundException.class,
|
||||
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
|
||||
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleHandshakeInitiatedRequestAttributes(final InetAddress preferredRemoteAddress,
|
||||
final String userAgentHeader,
|
||||
final String acceptLanguageHeader,
|
||||
final RequestAttributes expectedRequestAttributes) {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
|
||||
|
||||
GrpcClientConnectionManager.handleHandshakeInitiated(embeddedChannel,
|
||||
preferredRemoteAddress,
|
||||
userAgentHeader,
|
||||
acceptLanguageHeader);
|
||||
|
||||
assertEquals(expectedRequestAttributes,
|
||||
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
|
||||
}
|
||||
|
||||
private static List<Arguments> handleHandshakeInitiatedRequestAttributes() {
|
||||
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
|
||||
|
||||
return List.of(
|
||||
Arguments.argumentSet("Null User-Agent and Accept-Language headers",
|
||||
preferredRemoteAddress, null, null,
|
||||
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
|
||||
|
||||
Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
|
||||
preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
|
||||
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
|
||||
|
||||
Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
|
||||
preferredRemoteAddress, "Not a valid user-agent string", null,
|
||||
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
|
||||
|
||||
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
|
||||
preferredRemoteAddress, null, "ja,en;q=0.4",
|
||||
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
|
||||
|
||||
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
|
||||
preferredRemoteAddress, null, "This is not a valid language preference list",
|
||||
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
|
||||
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
|
||||
|
||||
assertThrows(ChannelNotFoundException.class,
|
||||
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
|
||||
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
|
||||
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
|
||||
|
||||
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
|
||||
remoteChannel.close().await();
|
||||
|
||||
assertThrows(ChannelNotFoundException.class,
|
||||
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
|
||||
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
|
||||
assertThrows(ChannelNotFoundException.class,
|
||||
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
|
||||
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
|
||||
|
||||
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
|
||||
remoteChannel.close().await();
|
||||
|
||||
assertThrows(ChannelNotFoundException.class,
|
||||
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.haproxy.HAProxyCommand;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
|
||||
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class HAProxyMessageHandlerTest {
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
embeddedChannel = new EmbeddedChannel(new HAProxyMessageHandler());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleHAProxyMessage() throws InterruptedException {
|
||||
final HAProxyMessage haProxyMessage = new HAProxyMessage(
|
||||
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
|
||||
"10.0.0.1", "10.0.0.2", 12345, 443);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(haProxyMessage);
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
assertEquals(0, haProxyMessage.refCnt());
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleNonHAProxyMessage() throws InterruptedException {
|
||||
final byte[] bytes = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(bytes);
|
||||
|
||||
final ByteBuf message = Unpooled.wrappedBuffer(bytes);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(message);
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
assertEquals(message, embeddedChannel.inboundMessages().poll());
|
||||
assertEquals(1, message.refCnt());
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
}
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import java.util.Optional;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class NoiseAnonymousHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
@Override
|
||||
protected CipherStatePair doHandshake() throws Exception {
|
||||
return doHandshake(baseHandshakeInit().build().toByteArray());
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] requestPayload) throws Exception {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
final HandshakeState clientHandshakeState =
|
||||
new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
|
||||
|
||||
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
|
||||
clientHandshakeState.start();
|
||||
|
||||
// Send initiator handshake message
|
||||
|
||||
// 32 byte key, request payload, 16 byte AEAD tag
|
||||
final int initiateHandshakeMessageLength = 32 + requestPayload.length + 16;
|
||||
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeMessageLength];
|
||||
assertEquals(
|
||||
initiateHandshakeMessageLength,
|
||||
clientHandshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length));
|
||||
final NoiseHandshakeInit message = new NoiseHandshakeInit(
|
||||
REMOTE_ADDRESS,
|
||||
HandshakePattern.NK,
|
||||
Unpooled.wrappedBuffer(initiateHandshakeMessage));
|
||||
assertTrue(embeddedChannel.writeOneInbound(message).await().isSuccess());
|
||||
assertEquals(0, message.refCnt());
|
||||
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// Read responder handshake message
|
||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||
final ByteBuf responderHandshakeFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||
assertNotNull(responderHandshakeFrame);
|
||||
final byte[] responderHandshakeBytes = ByteBufUtil.getBytes(responderHandshakeFrame);
|
||||
|
||||
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponse = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(NoiseTunnelProtos.HandshakeResponse.Code.OK)
|
||||
.build();
|
||||
|
||||
// ephemeral key, payload, AEAD tag
|
||||
assertEquals(32 + expectedHandshakeResponse.getSerializedSize() + 16, responderHandshakeBytes.length);
|
||||
|
||||
final byte[] handshakeResponsePlaintext = new byte[expectedHandshakeResponse.getSerializedSize()];
|
||||
assertEquals(expectedHandshakeResponse.getSerializedSize(),
|
||||
clientHandshakeState.readMessage(
|
||||
responderHandshakeBytes, 0, responderHandshakeBytes.length,
|
||||
handshakeResponsePlaintext, 0));
|
||||
|
||||
assertEquals(expectedHandshakeResponse, NoiseTunnelProtos.HandshakeResponse.parseFrom(handshakeResponsePlaintext));
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
|
||||
|
||||
return clientHandshakeState.split();
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeWithRequest() throws Exception {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final byte[] handshakePlaintext = baseHandshakeInit()
|
||||
.setFastOpenRequest(ByteString.copyFromUtf8("ping")).build()
|
||||
.toByteArray();
|
||||
|
||||
final CipherStatePair cipherStatePair = doHandshake(handshakePlaintext);
|
||||
final byte[] response = readNextPlaintext(cipherStatePair);
|
||||
assertArrayEquals(response, "pong".getBytes());
|
||||
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeNoRequest() throws ShortBufferException, BadPaddingException {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final CipherStatePair cipherStatePair = assertDoesNotThrow(() -> doHandshake());
|
||||
assertNull(readNextPlaintext(cipherStatePair));
|
||||
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(Optional.empty(), REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
}
|
||||
@@ -1,338 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.util.internal.EmptyArrays;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientTransportHandler;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
class NoiseAuthenticatedHandlerTest extends AbstractNoiseHandlerTest {
|
||||
|
||||
private final ECKeyPair clientKeyPair = ECKeyPair.generate();
|
||||
|
||||
@Override
|
||||
protected CipherStatePair doHandshake() throws Throwable {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = randomDeviceId();
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
return doHandshake(identityPayload(accountIdentifier, deviceId));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeNoInitialRequest() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = randomDeviceId();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
assertNull(readNextPlaintext(doHandshake(identityPayload(accountIdentifier, deviceId))));
|
||||
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
|
||||
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeWithInitialRequest() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = randomDeviceId();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(clientKeyPair.getPublicKey())));
|
||||
|
||||
final byte[] handshakeInit = identifiedHandshakeInit(accountIdentifier, deviceId)
|
||||
.setFastOpenRequest(ByteString.copyFromUtf8("ping"))
|
||||
.build()
|
||||
.toByteArray();
|
||||
|
||||
final byte[] response = readNextPlaintext(doHandshake(handshakeInit));
|
||||
assertEquals(4, response.length);
|
||||
assertEquals("pong", new String(response));
|
||||
|
||||
assertEquals(
|
||||
new NoiseIdentityDeterminedEvent(
|
||||
Optional.of(new AuthenticatedDevice(accountIdentifier, deviceId)),
|
||||
REMOTE_ADDRESS, USER_AGENT, ACCEPT_LANGUAGE),
|
||||
getNoiseHandshakeCompleteEvent());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeMissingIdentityInformation() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(EmptyArrays.EMPTY_BYTES));
|
||||
|
||||
verifyNoInteractions(clientPublicKeysManager);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeMalformedIdentityInformation() {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
// no deviceId byte
|
||||
byte[] malformedIdentityPayload = UUIDUtil.toBytes(UUID.randomUUID());
|
||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(malformedIdentityPayload));
|
||||
|
||||
verifyNoInteractions(clientPublicKeysManager);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakeUnrecognizedDevice() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = randomDeviceId();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
|
||||
|
||||
doHandshake(
|
||||
identityPayload(accountIdentifier, deviceId),
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
|
||||
assertNull(embeddedChannel.pipeline().get(NoiseClientTransportHandler.class),
|
||||
"Noise stream handler should not be added to pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleCompleteHandshakePublicKeyMismatch() throws Throwable {
|
||||
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = randomDeviceId();
|
||||
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(ECKeyPair.generate().getPublicKey())));
|
||||
|
||||
doHandshake(
|
||||
identityPayload(accountIdentifier, deviceId),
|
||||
NoiseTunnelProtos.HandshakeResponse.Code.WRONG_PUBLIC_KEY);
|
||||
|
||||
assertNull(getNoiseHandshakeCompleteEvent());
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class),
|
||||
"Handshake handler should not remove self from pipeline after failed handshake");
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleInvalidExtraWrites()
|
||||
throws NoSuchAlgorithmException, ShortBufferException, InterruptedException {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
assertNotNull(embeddedChannel.pipeline().get(NoiseHandshakeHandler.class));
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = randomDeviceId();
|
||||
|
||||
final HandshakeState clientHandshakeState = clientHandshakeState();
|
||||
|
||||
final CompletableFuture<Optional<ECPublicKey>> findPublicKeyFuture = new CompletableFuture<>();
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId)).thenReturn(findPublicKeyFuture);
|
||||
|
||||
final NoiseHandshakeInit handshakeInit = new NoiseHandshakeInit(
|
||||
REMOTE_ADDRESS,
|
||||
HandshakePattern.IK,
|
||||
Unpooled.wrappedBuffer(
|
||||
initiatorHandshakeMessage(clientHandshakeState, identityPayload(accountIdentifier, deviceId))));
|
||||
assertTrue(embeddedChannel.writeOneInbound(handshakeInit).await().isSuccess());
|
||||
|
||||
// While waiting for the public key, send another message
|
||||
final ChannelFuture f = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(new byte[0])).await();
|
||||
assertInstanceOf(IllegalArgumentException.class, f.exceptionNow());
|
||||
|
||||
findPublicKeyFuture.complete(Optional.of(clientKeyPair.getPublicKey()));
|
||||
embeddedChannel.runPendingTasks();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleOversizeHandshakeMessage() {
|
||||
final byte[] big = TestRandomUtil.nextBytes(Noise.MAX_PACKET_LEN + 1);
|
||||
ByteBuffer.wrap(big)
|
||||
.put(UUIDUtil.toBytes(UUID.randomUUID()))
|
||||
.put((byte) 0x01);
|
||||
assertThrows(NoiseHandshakeException.class, () -> doHandshake(big));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handleKeyLookupError() {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final byte deviceId = randomDeviceId();
|
||||
when(clientPublicKeysManager.findPublicKey(accountIdentifier, deviceId))
|
||||
.thenReturn(CompletableFuture.failedFuture(new IOException()));
|
||||
assertThrows(IOException.class, () -> doHandshake(identityPayload(accountIdentifier, deviceId)));
|
||||
}
|
||||
|
||||
private HandshakeState clientHandshakeState() throws NoSuchAlgorithmException {
|
||||
final HandshakeState clientHandshakeState =
|
||||
new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
||||
|
||||
clientHandshakeState.getLocalKeyPair().setPrivateKey(clientKeyPair.getPrivateKey().serialize(), 0);
|
||||
clientHandshakeState.getRemotePublicKey().setPublicKey(serverKeyPair.getPublicKey().getPublicKeyBytes(), 0);
|
||||
clientHandshakeState.start();
|
||||
return clientHandshakeState;
|
||||
}
|
||||
|
||||
private byte[] initiatorHandshakeMessage(final HandshakeState clientHandshakeState, final byte[] payload)
|
||||
throws ShortBufferException {
|
||||
// Ephemeral key, encrypted static key, AEAD tag, encrypted payload, AEAD tag
|
||||
final byte[] initiatorMessageBytes = new byte[32 + 32 + 16 + payload.length + 16];
|
||||
int written = clientHandshakeState.writeMessage(initiatorMessageBytes, 0, payload, 0, payload.length);
|
||||
assertEquals(written, initiatorMessageBytes.length);
|
||||
return initiatorMessageBytes;
|
||||
}
|
||||
|
||||
private byte[] readHandshakeResponse(final HandshakeState clientHandshakeState, final byte[] message)
|
||||
throws ShortBufferException, BadPaddingException {
|
||||
|
||||
// 32 byte ephemeral server key, 16 byte AEAD tag for encrypted payload
|
||||
final int expectedResponsePayloadLength = message.length - 32 - 16;
|
||||
final byte[] responsePayload = new byte[expectedResponsePayloadLength];
|
||||
final int responsePayloadLength = clientHandshakeState.readMessage(message, 0, message.length, responsePayload, 0);
|
||||
assertEquals(expectedResponsePayloadLength, responsePayloadLength);
|
||||
return responsePayload;
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] payload) throws Throwable {
|
||||
return doHandshake(payload, NoiseTunnelProtos.HandshakeResponse.Code.OK);
|
||||
}
|
||||
|
||||
private CipherStatePair doHandshake(final byte[] payload, final NoiseTunnelProtos.HandshakeResponse.Code expectedStatus) throws Throwable {
|
||||
final EmbeddedChannel embeddedChannel = getEmbeddedChannel();
|
||||
|
||||
final HandshakeState clientHandshakeState = clientHandshakeState();
|
||||
final byte[] initiatorMessage = initiatorHandshakeMessage(clientHandshakeState, payload);
|
||||
|
||||
final NoiseHandshakeInit initMessage = new NoiseHandshakeInit(
|
||||
REMOTE_ADDRESS,
|
||||
HandshakePattern.IK,
|
||||
Unpooled.wrappedBuffer(initiatorMessage));
|
||||
final ChannelFuture await = embeddedChannel.writeOneInbound(initMessage).await();
|
||||
assertEquals(0, initMessage.refCnt());
|
||||
if (!await.isSuccess() && expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
|
||||
throw await.cause();
|
||||
}
|
||||
|
||||
// The handshake handler makes an asynchronous call to get the stored public key for the client, then handles the
|
||||
// result on its event loop. Because this is an embedded channel, this all happens on the main thread (i.e. the same
|
||||
// thread as this test), and so we need to nudge things forward to actually process the "found credentials" callback
|
||||
// and issue a "handshake complete" event.
|
||||
embeddedChannel.runPendingTasks();
|
||||
|
||||
// rethrow if running the task caused an error, and the caller isn't expecting an error
|
||||
if (expectedStatus == NoiseTunnelProtos.HandshakeResponse.Code.OK) {
|
||||
embeddedChannel.checkException();
|
||||
}
|
||||
|
||||
assertFalse(embeddedChannel.outboundMessages().isEmpty());
|
||||
|
||||
final ByteBuf handshakeResponseFrame = (ByteBuf) embeddedChannel.outboundMessages().poll();
|
||||
assertNotNull(handshakeResponseFrame);
|
||||
final byte[] handshakeResponseCiphertextBytes = ByteBufUtil.getBytes(handshakeResponseFrame);
|
||||
|
||||
final NoiseTunnelProtos.HandshakeResponse expectedHandshakeResponsePlaintext = NoiseTunnelProtos.HandshakeResponse.newBuilder()
|
||||
.setCode(expectedStatus)
|
||||
.build();
|
||||
|
||||
final byte[] actualHandshakeResponsePlaintext =
|
||||
readHandshakeResponse(clientHandshakeState, handshakeResponseCiphertextBytes);
|
||||
|
||||
assertEquals(
|
||||
expectedHandshakeResponsePlaintext,
|
||||
NoiseTunnelProtos.HandshakeResponse.parseFrom(actualHandshakeResponsePlaintext));
|
||||
|
||||
final byte[] serverPublicKey = new byte[32];
|
||||
clientHandshakeState.getRemotePublicKey().getPublicKey(serverPublicKey, 0);
|
||||
assertArrayEquals(serverPublicKey, serverKeyPair.getPublicKey().getPublicKeyBytes());
|
||||
|
||||
return clientHandshakeState.split();
|
||||
}
|
||||
|
||||
private NoiseTunnelProtos.HandshakeInit.Builder identifiedHandshakeInit(final UUID accountIdentifier, final byte deviceId) {
|
||||
return baseHandshakeInit()
|
||||
.setAci(UUIDUtil.toByteString(accountIdentifier))
|
||||
.setDeviceId(deviceId);
|
||||
}
|
||||
|
||||
private byte[] identityPayload(final UUID accountIdentifier, final byte deviceId) {
|
||||
return identifiedHandshakeInit(accountIdentifier, deviceId)
|
||||
.build()
|
||||
.toByteArray();
|
||||
}
|
||||
|
||||
private static byte randomDeviceId() {
|
||||
return (byte) ThreadLocalRandom.current().nextInt(1, Device.MAXIMUM_DEVICE_ID + 1);
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatNoException;
|
||||
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseClientHandshakeHelper;
|
||||
|
||||
|
||||
public class NoiseHandshakeHelperTest {
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(HandshakePattern.class)
|
||||
void testWithPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
|
||||
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), "pong".getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(HandshakePattern.class)
|
||||
void testWithRequestPayload(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
|
||||
doHandshake(pattern, "ping".getBytes(StandardCharsets.UTF_8), new byte[0]);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(HandshakePattern.class)
|
||||
void testWithoutPayloads(final HandshakePattern pattern) throws ShortBufferException, NoiseHandshakeException {
|
||||
doHandshake(pattern, new byte[0], new byte[0]);
|
||||
}
|
||||
|
||||
void doHandshake(final HandshakePattern pattern, final byte[] requestPayload, final byte[] responsePayload) throws ShortBufferException, NoiseHandshakeException {
|
||||
final ECKeyPair serverKeyPair = ECKeyPair.generate();
|
||||
final ECKeyPair clientKeyPair = ECKeyPair.generate();
|
||||
|
||||
NoiseHandshakeHelper serverHelper = new NoiseHandshakeHelper(pattern, serverKeyPair);
|
||||
NoiseClientHandshakeHelper clientHelper = switch (pattern) {
|
||||
case IK -> NoiseClientHandshakeHelper.IK(serverKeyPair.getPublicKey(), clientKeyPair);
|
||||
case NK -> NoiseClientHandshakeHelper.NK(serverKeyPair.getPublicKey());
|
||||
};
|
||||
|
||||
final byte[] initiate = clientHelper.write(requestPayload);
|
||||
final ByteBuf actualRequestPayload = serverHelper.read(initiate);
|
||||
assertThat(ByteBufUtil.getBytes(actualRequestPayload)).isEqualTo(requestPayload);
|
||||
|
||||
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.WRITE_MESSAGE);
|
||||
|
||||
final byte[] respond = serverHelper.write(responsePayload);
|
||||
byte[] actualResponsePayload = clientHelper.read(respond);
|
||||
assertThat(actualResponsePayload).isEqualTo(responsePayload);
|
||||
|
||||
assertThat(serverHelper.getHandshakeState().getAction()).isEqualTo(HandshakeState.SPLIT);
|
||||
assertThatNoException().isThrownBy(() -> serverHelper.getHandshakeState().split());
|
||||
assertThatNoException().isThrownBy(() -> clientHelper.split());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import java.util.HexFormat;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class ProxyProtocolDetectionHandlerTest {
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
private static final byte[] PROXY_V2_MESSAGE_BYTES =
|
||||
HexFormat.of().parseHex("0d0a0d0a000d0a515549540a2111000c0a0000010a000002303901bb");
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
embeddedChannel = new EmbeddedChannel(new ProxyProtocolDetectionHandler());
|
||||
}
|
||||
|
||||
@Test
|
||||
void singlePacketProxyMessage() throws InterruptedException {
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES));
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
|
||||
}
|
||||
|
||||
@Test
|
||||
void multiPacketProxyMessage() throws InterruptedException {
|
||||
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, 0,
|
||||
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
|
||||
|
||||
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(PROXY_V2_MESSAGE_BYTES, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
|
||||
PROXY_V2_MESSAGE_BYTES.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
|
||||
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
firstWriteFuture.await();
|
||||
secondWriteFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
assertInstanceOf(HAProxyMessage.class, embeddedChannel.inboundMessages().poll());
|
||||
}
|
||||
|
||||
@Test
|
||||
void singlePacketNonProxyMessage() throws InterruptedException {
|
||||
final byte[] nonProxyProtocolMessage = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
|
||||
|
||||
final ChannelFuture writeFuture = embeddedChannel.writeOneInbound(Unpooled.wrappedBuffer(nonProxyProtocolMessage));
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
writeFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
|
||||
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
|
||||
|
||||
assertInstanceOf(ByteBuf.class, inboundMessage);
|
||||
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
|
||||
}
|
||||
|
||||
@Test
|
||||
void multiPacketNonProxyMessage() throws InterruptedException {
|
||||
final byte[] nonProxyProtocolMessage = new byte[32];
|
||||
ThreadLocalRandom.current().nextBytes(nonProxyProtocolMessage);
|
||||
|
||||
final ChannelFuture firstWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(nonProxyProtocolMessage, 0,
|
||||
ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1));
|
||||
|
||||
final ChannelFuture secondWriteFuture = embeddedChannel.writeOneInbound(
|
||||
Unpooled.wrappedBuffer(nonProxyProtocolMessage, ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1,
|
||||
nonProxyProtocolMessage.length - (ProxyProtocolDetectionHandler.PROXY_MESSAGE_DETECTION_BYTES - 1)));
|
||||
|
||||
embeddedChannel.flushInbound();
|
||||
|
||||
firstWriteFuture.await();
|
||||
secondWriteFuture.await();
|
||||
|
||||
assertTrue(embeddedChannel.pipeline().toMap().isEmpty());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
|
||||
final Object inboundMessage = embeddedChannel.inboundMessages().poll();
|
||||
|
||||
assertInstanceOf(ByteBuf.class, inboundMessage);
|
||||
assertArrayEquals(nonProxyProtocolMessage, ByteBufUtil.getBytes((ByteBuf) inboundMessage));
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class ClientErrorHandler extends ChannelInboundHandlerAdapter {
|
||||
private static final Logger log = LoggerFactory.getLogger(ClientErrorHandler.class);
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(final ChannelHandlerContext context, final Throwable cause) {
|
||||
log.error("Caught inbound error in client; closing connection", cause);
|
||||
context.channel().close();
|
||||
}
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||
|
||||
public record CloseFrameEvent(CloseReason closeReason, CloseInitiator closeInitiator, String reason) {
|
||||
|
||||
public enum CloseReason {
|
||||
OK,
|
||||
SERVER_CLOSED,
|
||||
NOISE_ERROR,
|
||||
NOISE_HANDSHAKE_ERROR,
|
||||
INTERNAL_SERVER_ERROR,
|
||||
UNKNOWN
|
||||
}
|
||||
|
||||
public enum CloseInitiator {
|
||||
SERVER,
|
||||
CLIENT
|
||||
}
|
||||
|
||||
public static CloseFrameEvent fromWebsocketCloseFrame(
|
||||
CloseWebSocketFrame closeWebSocketFrame,
|
||||
CloseInitiator closeInitiator) {
|
||||
final CloseReason code = switch (closeWebSocketFrame.statusCode()) {
|
||||
case 4001 -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||
case 4002 -> CloseReason.NOISE_ERROR;
|
||||
case 1011 -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||
case 1012 -> CloseReason.SERVER_CLOSED;
|
||||
case 1000 -> CloseReason.OK;
|
||||
default -> CloseReason.UNKNOWN;
|
||||
};
|
||||
return new CloseFrameEvent(code, closeInitiator, closeWebSocketFrame.reasonText());
|
||||
}
|
||||
|
||||
public static CloseFrameEvent fromNoiseDirectCloseFrame(
|
||||
NoiseDirectProtos.CloseReason noiseDirectCloseReason,
|
||||
CloseInitiator closeInitiator) {
|
||||
final CloseReason code = switch (noiseDirectCloseReason.getCode()) {
|
||||
case OK -> CloseReason.OK;
|
||||
case HANDSHAKE_ERROR -> CloseReason.NOISE_HANDSHAKE_ERROR;
|
||||
case ENCRYPTION_ERROR -> CloseReason.NOISE_ERROR;
|
||||
case UNAVAILABLE -> CloseReason.SERVER_CLOSED;
|
||||
case INTERNAL_ERROR -> CloseReason.INTERNAL_SERVER_ERROR;
|
||||
case UNRECOGNIZED, UNSPECIFIED -> CloseReason.UNKNOWN;
|
||||
};
|
||||
return new CloseFrameEvent(code, closeInitiator, noiseDirectCloseReason.getMessage());
|
||||
}
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.SocketAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.ProxyHandler;
|
||||
|
||||
/**
|
||||
* Handler that takes plaintext inbound messages from a gRPC client and forwards them over the noise tunnel to a remote
|
||||
* gRPC server.
|
||||
* <p>
|
||||
* This handler waits until the first gRPC client message is ready and then establishes a connection with the remote
|
||||
* gRPC server. It expects the provided remoteHandlerStack to emit a {@link ReadyForNoiseHandshakeEvent} when the remote
|
||||
* connection is ready for its first inbound payload, and to emit a {@link NoiseClientHandshakeCompleteEvent} when the
|
||||
* handshake is finished.
|
||||
*/
|
||||
class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final List<ChannelHandler> remoteHandlerStack;
|
||||
private final NoiseTunnelProtos.HandshakeInit handshakeInit;
|
||||
|
||||
private final SocketAddress remoteServerAddress;
|
||||
// If provided, will be sent with the payload in the noise handshake
|
||||
|
||||
private final List<Object> pendingReads = new ArrayList<>();
|
||||
|
||||
private static final String NOISE_HANDSHAKE_HANDLER_NAME = "noise-handshake";
|
||||
|
||||
EstablishRemoteConnectionHandler(
|
||||
final List<ChannelHandler> remoteHandlerStack,
|
||||
final SocketAddress remoteServerAddress,
|
||||
final NoiseTunnelProtos.HandshakeInit handshakeInit) {
|
||||
this.remoteHandlerStack = remoteHandlerStack;
|
||||
this.handshakeInit = handshakeInit;
|
||||
this.remoteServerAddress = remoteServerAddress;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(final ChannelHandlerContext localContext) {
|
||||
new Bootstrap()
|
||||
.channel(NioSocketChannel.class)
|
||||
.group(localContext.channel().eventLoop())
|
||||
.handler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(final SocketChannel channel) throws Exception {
|
||||
|
||||
for (ChannelHandler handler : remoteHandlerStack) {
|
||||
channel.pipeline().addLast(handler);
|
||||
}
|
||||
channel.pipeline()
|
||||
.addLast(NOISE_HANDSHAKE_HANDLER_NAME, new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext remoteContext, final Object event)
|
||||
throws Exception {
|
||||
switch (event) {
|
||||
case ReadyForNoiseHandshakeEvent ignored ->
|
||||
remoteContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeInit.toByteArray()))
|
||||
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
|
||||
case NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) -> {
|
||||
remoteContext.pipeline()
|
||||
.replace(NOISE_HANDSHAKE_HANDLER_NAME, null, new ProxyHandler(localContext.channel()));
|
||||
localContext.pipeline().addLast(new ProxyHandler(remoteContext.channel()));
|
||||
|
||||
// If there was a payload response on the handshake, write it back to our gRPC client
|
||||
if (!handshakeResponse.getFastOpenResponse().isEmpty()) {
|
||||
localContext.writeAndFlush(Unpooled.wrappedBuffer(handshakeResponse
|
||||
.getFastOpenResponse()
|
||||
.asReadOnlyByteBuffer()));
|
||||
}
|
||||
|
||||
// Forward any messages we got from our gRPC client, now will be proxied to the remote context
|
||||
pendingReads.forEach(localContext::fireChannelRead);
|
||||
pendingReads.clear();
|
||||
localContext.pipeline().remove(EstablishRemoteConnectionHandler.this);
|
||||
}
|
||||
default -> {
|
||||
}
|
||||
}
|
||||
super.userEventTriggered(remoteContext, event);
|
||||
}
|
||||
})
|
||||
.addLast(new ClientErrorHandler());
|
||||
}
|
||||
})
|
||||
.connect(remoteServerAddress)
|
||||
.addListener((ChannelFutureListener) future -> {
|
||||
if (future.isSuccess()) {
|
||||
// Close the local connection if the remote channel closes and vice versa
|
||||
future.channel().closeFuture().addListener(closeFuture -> localContext.channel().close());
|
||||
localContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
|
||||
} else {
|
||||
localContext.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||
pendingReads.add(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
pendingReads.forEach(ReferenceCountUtil::release);
|
||||
pendingReads.clear();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
|
||||
public record FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest) {}
|
||||
@@ -1,28 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
class HAProxyMessageSender extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final Supplier<HAProxyMessage> messageSupplier;
|
||||
|
||||
HAProxyMessageSender(final Supplier<HAProxyMessage> messageSupplier) {
|
||||
this.messageSupplier = messageSupplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(final ChannelHandlerContext context) {
|
||||
if (context.channel().isActive()) {
|
||||
context.writeAndFlush(messageSupplier.get());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(final ChannelHandlerContext context) {
|
||||
context.writeAndFlush(messageSupplier.get());
|
||||
context.fireChannelActive();
|
||||
}
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandler;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HexFormat;
|
||||
import java.util.List;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* The noise tunnel streams bytes out of a gRPC client through noise and to a remote server. The server supports a "fast
|
||||
* open" optimization where the client can send a request along with the noise handshake. There's no direct way to
|
||||
* extract the request boundaries from the gRPC client's byte-stream, so {@link Http2Buffering#handler()} provides an
|
||||
* inbound pipeline handler that will parse the byte-stream back into HTTP/2 frames and buffer the first request.
|
||||
* <p>
|
||||
* Once an entire request has been buffered, the handler will remove itself from the pipeline and emit a
|
||||
* {@link FastOpenRequestBufferedEvent}
|
||||
*/
|
||||
public class Http2Buffering {
|
||||
|
||||
/**
|
||||
* Create a pipeline handler that consumes serialized HTTP/2 ByteBufs and emits a fast-open request
|
||||
*/
|
||||
public static ChannelInboundHandler handler() {
|
||||
return new Http2PrefaceHandler();
|
||||
}
|
||||
|
||||
private Http2Buffering() {
|
||||
}
|
||||
|
||||
private static class Http2PrefaceHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc7540.html#section-3.5
|
||||
private static final byte[] HTTP2_PREFACE =
|
||||
HexFormat.of().parseHex("505249202a20485454502f322e300d0a0d0a534d0d0a0d0a");
|
||||
private final ByteBuf read = Unpooled.buffer(HTTP2_PREFACE.length, HTTP2_PREFACE.length);
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||
if (message instanceof ByteBuf bb) {
|
||||
bb.readBytes(read);
|
||||
if (read.readableBytes() < HTTP2_PREFACE.length) {
|
||||
// Copied the message into the read buffer, but haven't yet got a full HTTP2 preface. Wait for more.
|
||||
return;
|
||||
}
|
||||
if (!Arrays.equals(read.array(), HTTP2_PREFACE)) {
|
||||
throw new IllegalStateException("HTTP/2 stream must start with HTTP/2 preface");
|
||||
}
|
||||
context.pipeline().replace(this, "http2frame1", new Http2LengthFieldFrameDecoder());
|
||||
context.pipeline().addAfter("http2frame1", "http2frame2", new Http2FrameDecoder());
|
||||
context.pipeline().addAfter("http2frame2", "http2frame3", new Http2FirstRequestHandler());
|
||||
context.fireChannelRead(bb);
|
||||
} else {
|
||||
throw new IllegalStateException("Unexpected message: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
ReferenceCountUtil.release(read);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private record Http2Frame(ByteBuf bytes, FrameType type, boolean endStream) {
|
||||
|
||||
private static final byte FLAG_END_STREAM = 0x01;
|
||||
|
||||
enum FrameType {
|
||||
SETTINGS,
|
||||
HEADERS,
|
||||
DATA,
|
||||
WINDOW_UPDATE,
|
||||
OTHER;
|
||||
|
||||
static FrameType fromSerializedType(final byte type) {
|
||||
return switch (type) {
|
||||
case 0x00 -> Http2Frame.FrameType.DATA;
|
||||
case 0x01 -> Http2Frame.FrameType.HEADERS;
|
||||
case 0x04 -> Http2Frame.FrameType.SETTINGS;
|
||||
case 0x08 -> Http2Frame.FrameType.WINDOW_UPDATE;
|
||||
default -> Http2Frame.FrameType.OTHER;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit ByteBuf of entire HTTP/2 frame
|
||||
*/
|
||||
private static class Http2LengthFieldFrameDecoder extends LengthFieldBasedFrameDecoder {
|
||||
|
||||
public Http2LengthFieldFrameDecoder() {
|
||||
// Frames are 3 bytes of length, 6 bytes of other header, and then length bytes of payload
|
||||
super(16 * 1024 * 1024, 0, 3, 6, 0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the serialized Http/2 frames into {@link Http2Frame} objects
|
||||
*/
|
||||
private static class Http2FrameDecoder extends ByteToMessageDecoder {
|
||||
|
||||
@Override
|
||||
protected void decode(final ChannelHandlerContext ctx, final ByteBuf in, final List<Object> out) throws Exception {
|
||||
// https://www.rfc-editor.org/rfc/rfc7540.html#section-4.1
|
||||
final Http2Frame.FrameType frameType = Http2Frame.FrameType.fromSerializedType(in.getByte(in.readerIndex() + 3));
|
||||
final boolean endStream = endStream(frameType, in.getByte(in.readerIndex() + 4));
|
||||
out.add(new Http2Frame(in.readBytes(in.readableBytes()), frameType, endStream));
|
||||
}
|
||||
|
||||
boolean endStream(Http2Frame.FrameType frameType, byte flags) {
|
||||
// A gRPC request are packed into HTTP/2 frames like:
|
||||
// HEADERS frame | DATA frame 1 (endStream=0) | ... | DATA frame N (endstream=1)
|
||||
//
|
||||
// Our goal is to get an entire request buffered, so as soon as we see a DATA frame with the end stream flag set
|
||||
// we have a whole request. Note that we could have pieces of multiple requests, but the only thing we care about
|
||||
// is having at least one complete request. In total, we can expect something like:
|
||||
// HTTP-preface | SETTINGS frame | Frames we don't care about ... | DATA (endstream=1)
|
||||
//
|
||||
// The connection isn't 'established' until the server has responded with their own SETTINGS frame with the ack
|
||||
// bit set, but HTTP/2 allows the client to send frames before getting the ACK.
|
||||
if (frameType == Http2Frame.FrameType.DATA) {
|
||||
return (flags & Http2Frame.FLAG_END_STREAM) == Http2Frame.FLAG_END_STREAM;
|
||||
}
|
||||
|
||||
// In theory, at least. Unfortunately, the java gRPC client always waits for the HTTP/2 handshake to complete
|
||||
// (which requires the server sending back the ack) before it actually sends any requests. So if we waited for a
|
||||
// DATA frame, it would never come. The gRPC-java implementation always at least sends a WINDOW_UPDATE, so we
|
||||
// might as well pack that in.
|
||||
return frameType == Http2Frame.FrameType.WINDOW_UPDATE;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Collect HTTP/2 frames until we get an entire "request" to send
|
||||
*/
|
||||
private static class Http2FirstRequestHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
final List<Http2Frame> pendingFrames = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) {
|
||||
if (message instanceof Http2Frame http2Frame) {
|
||||
if (pendingFrames.isEmpty() && http2Frame.type != Http2Frame.FrameType.SETTINGS) {
|
||||
throw new IllegalStateException(
|
||||
"HTTP/2 stream must start with HTTP/2 SETTINGS frame, got " + http2Frame.type);
|
||||
}
|
||||
pendingFrames.add(http2Frame);
|
||||
if (http2Frame.endStream) {
|
||||
// We have a whole "request", emit the first request event and remove the http2 buffering handlers
|
||||
final ByteBuf request = Unpooled.wrappedBuffer(Stream.concat(
|
||||
Stream.of(Unpooled.wrappedBuffer(Http2PrefaceHandler.HTTP2_PREFACE)),
|
||||
pendingFrames.stream().map(Http2Frame::bytes))
|
||||
.toArray(ByteBuf[]::new));
|
||||
pendingFrames.clear();
|
||||
context.pipeline().remove(Http2LengthFieldFrameDecoder.class);
|
||||
context.pipeline().remove(Http2FrameDecoder.class);
|
||||
context.pipeline().remove(this);
|
||||
context.fireUserEventTriggered(new FastOpenRequestBufferedEvent(request));
|
||||
}
|
||||
} else {
|
||||
throw new IllegalStateException("Unexpected message: " + message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
pendingFrames.forEach(frame -> ReferenceCountUtil.release(frame.bytes()));
|
||||
pendingFrames.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* A netty user event that indicates that the noise handshake finished successfully.
|
||||
*
|
||||
* @param fastResponse A response if the client included a request to send in the initiate handshake message payload and
|
||||
* the server included a payload in the handshake response.
|
||||
*/
|
||||
public record NoiseClientHandshakeCompleteEvent(NoiseTunnelProtos.HandshakeResponse handshakeResponse) {}
|
||||
@@ -1,62 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import java.util.Optional;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.OutboundCloseErrorMessage;
|
||||
|
||||
public class NoiseClientHandshakeHandler extends ChannelDuplexHandler {
|
||||
|
||||
private final NoiseClientHandshakeHelper handshakeHelper;
|
||||
|
||||
public NoiseClientHandshakeHandler(NoiseClientHandshakeHelper handshakeHelper) {
|
||||
this.handshakeHelper = handshakeHelper;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof ByteBuf plaintextHandshakePayload) {
|
||||
final byte[] payloadBytes = ByteBufUtil.getBytes(plaintextHandshakePayload,
|
||||
plaintextHandshakePayload.readerIndex(), plaintextHandshakePayload.readableBytes(),
|
||||
false);
|
||||
final byte[] handshakeMessage = handshakeHelper.write(payloadBytes);
|
||||
ctx.write(Unpooled.wrappedBuffer(handshakeMessage), promise);
|
||||
} else {
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message)
|
||||
throws NoiseHandshakeException {
|
||||
if (message instanceof ByteBuf frame) {
|
||||
try {
|
||||
final byte[] payload = handshakeHelper.read(ByteBufUtil.getBytes(frame));
|
||||
final NoiseTunnelProtos.HandshakeResponse handshakeResponse =
|
||||
NoiseTunnelProtos.HandshakeResponse.parseFrom(payload);
|
||||
|
||||
context.pipeline().replace(this, null, new NoiseClientTransportHandler(handshakeHelper.split()));
|
||||
context.fireUserEventTriggered(new NoiseClientHandshakeCompleteEvent(handshakeResponse));
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new NoiseHandshakeException("Failed to parse handshake response");
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
} else {
|
||||
context.fireChannelRead(message);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
handshakeHelper.destroy();
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import com.southernstorm.noise.protocol.HandshakeState;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.ShortBufferException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeException;
|
||||
|
||||
public class NoiseClientHandshakeHelper {
|
||||
|
||||
private final HandshakePattern handshakePattern;
|
||||
private final HandshakeState handshakeState;
|
||||
|
||||
private NoiseClientHandshakeHelper(HandshakePattern handshakePattern, HandshakeState handshakeState) {
|
||||
this.handshakePattern = handshakePattern;
|
||||
this.handshakeState = handshakeState;
|
||||
}
|
||||
|
||||
public static NoiseClientHandshakeHelper IK(ECPublicKey serverStaticKey, ECKeyPair clientStaticKey) {
|
||||
try {
|
||||
final HandshakeState state = new HandshakeState(HandshakePattern.IK.protocol(), HandshakeState.INITIATOR);
|
||||
state.getLocalKeyPair().setPrivateKey(clientStaticKey.getPrivateKey().serialize(), 0);
|
||||
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
|
||||
state.start();
|
||||
return new NoiseClientHandshakeHelper(HandshakePattern.IK, state);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static NoiseClientHandshakeHelper NK(ECPublicKey serverStaticKey) {
|
||||
try {
|
||||
final HandshakeState state = new HandshakeState(HandshakePattern.NK.protocol(), HandshakeState.INITIATOR);
|
||||
state.getRemotePublicKey().setPublicKey(serverStaticKey.getPublicKeyBytes(), 0);
|
||||
state.start();
|
||||
return new NoiseClientHandshakeHelper(HandshakePattern.NK, state);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] write(final byte[] requestPayload) throws ShortBufferException {
|
||||
final byte[] initiateHandshakeMessage = new byte[initiateHandshakeKeysLength() + requestPayload.length + 16];
|
||||
handshakeState.writeMessage(initiateHandshakeMessage, 0, requestPayload, 0, requestPayload.length);
|
||||
return initiateHandshakeMessage;
|
||||
}
|
||||
|
||||
private int initiateHandshakeKeysLength() {
|
||||
return switch (handshakePattern) {
|
||||
// 32-byte ephemeral key, 32-byte encrypted static key, 16-byte AEAD tag
|
||||
case IK -> 32 + 32 + 16;
|
||||
// 32-byte ephemeral key
|
||||
case NK -> 32;
|
||||
};
|
||||
}
|
||||
|
||||
public byte[] read(final byte[] responderHandshakeMessage) throws NoiseHandshakeException {
|
||||
// Don't process additional messages if the handshake failed and we're just waiting to close
|
||||
if (handshakeState.getAction() != HandshakeState.READ_MESSAGE) {
|
||||
throw new NoiseHandshakeException("Received message with handshake state " + handshakeState.getAction());
|
||||
}
|
||||
final int payloadLength = responderHandshakeMessage.length - 16 - 32;
|
||||
final byte[] responsePayload = new byte[payloadLength];
|
||||
final int payloadBytesRead;
|
||||
try {
|
||||
payloadBytesRead = handshakeState
|
||||
.readMessage(responderHandshakeMessage, 0, responderHandshakeMessage.length, responsePayload, 0);
|
||||
if (payloadBytesRead != responsePayload.length) {
|
||||
throw new IllegalStateException(
|
||||
"Unexpected payload length, required " + payloadLength + " got " + payloadBytesRead);
|
||||
}
|
||||
return responsePayload;
|
||||
} catch (ShortBufferException e) {
|
||||
throw new IllegalStateException("Failed to deserialize payload of known length" + e.getMessage());
|
||||
} catch (BadPaddingException e) {
|
||||
throw new NoiseHandshakeException(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public CipherStatePair split() {
|
||||
return this.handshakeState.split();
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
this.handshakeState.destroy();
|
||||
}
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import com.southernstorm.noise.protocol.CipherState;
|
||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||
|
||||
/**
|
||||
* A Noise transport handler manages a bidirectional Noise session after a handshake has completed.
|
||||
*/
|
||||
public class NoiseClientTransportHandler extends ChannelDuplexHandler {
|
||||
|
||||
private final CipherStatePair cipherStatePair;
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(NoiseClientTransportHandler.class);
|
||||
|
||||
NoiseClientTransportHandler(CipherStatePair cipherStatePair) {
|
||||
this.cipherStatePair = cipherStatePair;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
try {
|
||||
if (message instanceof ByteBuf frame) {
|
||||
final CipherState cipherState = cipherStatePair.getReceiver();
|
||||
|
||||
// We've read this frame off the wire, and so it's most likely a direct buffer that's not backed by an array.
|
||||
// We'll need to copy it to a heap buffer.
|
||||
final byte[] noiseBuffer = ByteBufUtil.getBytes(frame);
|
||||
|
||||
// Overwrite the ciphertext with the plaintext to avoid an extra allocation for a dedicated plaintext buffer
|
||||
final int plaintextLength = cipherState.decryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, noiseBuffer.length);
|
||||
|
||||
context.fireChannelRead(Unpooled.wrappedBuffer(noiseBuffer, 0, plaintextLength));
|
||||
} else {
|
||||
// Anything except binary frames should have been filtered out of the pipeline by now; treat this as an
|
||||
// error
|
||||
throw new IllegalArgumentException("Unexpected message in pipeline: " + message);
|
||||
}
|
||||
} finally {
|
||||
ReferenceCountUtil.release(message);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||
throws Exception {
|
||||
if (message instanceof ByteBuf plaintext) {
|
||||
try {
|
||||
final CipherState cipherState = cipherStatePair.getSender();
|
||||
final int plaintextLength = plaintext.readableBytes();
|
||||
|
||||
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
|
||||
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
|
||||
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
|
||||
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
|
||||
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
|
||||
|
||||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||
|
||||
context.write(Unpooled.wrappedBuffer(noiseBuffer), promise);
|
||||
} finally {
|
||||
ReferenceCountUtil.release(plaintext);
|
||||
}
|
||||
} else {
|
||||
if (!(message instanceof CloseWebSocketFrame || message instanceof NoiseDirectFrame)) {
|
||||
// Clients only write ByteBufs or a close frame on errors, so any other message is unexpected
|
||||
log.warn("Unexpected object in pipeline: {}", message);
|
||||
}
|
||||
context.write(message, promise);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(final ChannelHandlerContext context) {
|
||||
cipherStatePair.destroy();
|
||||
}
|
||||
}
|
||||
@@ -1,408 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.southernstorm.noise.protocol.Noise;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.local.LocalChannel;
|
||||
import io.netty.channel.local.LocalServerChannel;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||
import io.netty.handler.codec.MessageToMessageCodec;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessage;
|
||||
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpClientCodec;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpObjectAggregator;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
|
||||
import io.netty.handler.ssl.SslContextBuilder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.net.SocketAddress;
|
||||
import java.net.URI;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
import javax.net.ssl.SSLException;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.FramingType;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseTunnelProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrame;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectFrameCodec;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.noisedirect.NoiseDirectProtos;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.websocket.WebsocketPayloadCodec;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
public class NoiseTunnelClient implements AutoCloseable {
|
||||
|
||||
private final CompletableFuture<CloseFrameEvent> closeEventFuture;
|
||||
private final CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture;
|
||||
private final CompletableFuture<Void> userCloseFuture;
|
||||
private final ServerBootstrap serverBootstrap;
|
||||
private Channel serverChannel;
|
||||
|
||||
public static final URI AUTHENTICATED_WEBSOCKET_URI = URI.create("wss://localhost/authenticated");
|
||||
public static final URI ANONYMOUS_WEBSOCKET_URI = URI.create("wss://localhost/anonymous");
|
||||
|
||||
public static class Builder {
|
||||
|
||||
final SocketAddress remoteServerAddress;
|
||||
NioEventLoopGroup eventLoopGroup;
|
||||
ECPublicKey serverPublicKey;
|
||||
|
||||
FramingType framingType = FramingType.WEBSOCKET;
|
||||
URI websocketUri = ANONYMOUS_WEBSOCKET_URI;
|
||||
HttpHeaders headers = new DefaultHttpHeaders();
|
||||
NoiseTunnelProtos.HandshakeInit.Builder handshakeInit = NoiseTunnelProtos.HandshakeInit.newBuilder();
|
||||
|
||||
boolean authenticated = false;
|
||||
ECKeyPair ecKeyPair = null;
|
||||
boolean useTls;
|
||||
X509Certificate trustedServerCertificate = null;
|
||||
Supplier<HAProxyMessage> proxyMessageSupplier = null;
|
||||
|
||||
public Builder(
|
||||
final SocketAddress remoteServerAddress,
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final ECPublicKey serverPublicKey) {
|
||||
this.remoteServerAddress = remoteServerAddress;
|
||||
this.eventLoopGroup = eventLoopGroup;
|
||||
this.serverPublicKey = serverPublicKey;
|
||||
}
|
||||
|
||||
public Builder setAuthenticated(final ECKeyPair ecKeyPair, final UUID accountIdentifier, final byte deviceId) {
|
||||
this.authenticated = true;
|
||||
handshakeInit.setAci(UUIDUtil.toByteString(accountIdentifier));
|
||||
handshakeInit.setDeviceId(deviceId);
|
||||
this.ecKeyPair = ecKeyPair;
|
||||
this.websocketUri = AUTHENTICATED_WEBSOCKET_URI;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setWebsocketUri(final URI websocketUri) {
|
||||
this.websocketUri = websocketUri;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setUseTls(X509Certificate trustedServerCertificate) {
|
||||
this.useTls = true;
|
||||
this.trustedServerCertificate = trustedServerCertificate;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setProxyMessageSupplier(Supplier<HAProxyMessage> proxyMessageSupplier) {
|
||||
this.proxyMessageSupplier = proxyMessageSupplier;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setUserAgent(final String userAgent) {
|
||||
handshakeInit.setUserAgent(userAgent);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setAcceptLanguage(final String acceptLanguage) {
|
||||
handshakeInit.setAcceptLanguage(acceptLanguage);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHeaders(final HttpHeaders headers) {
|
||||
this.headers = headers;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setServerPublicKey(ECPublicKey serverPublicKey) {
|
||||
this.serverPublicKey = serverPublicKey;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setFramingType(FramingType framingType) {
|
||||
this.framingType = framingType;
|
||||
return this;
|
||||
}
|
||||
|
||||
public NoiseTunnelClient build() {
|
||||
final List<ChannelHandler> handlers = new ArrayList<>();
|
||||
if (proxyMessageSupplier != null) {
|
||||
handlers.addAll(List.of(HAProxyMessageEncoder.INSTANCE, new HAProxyMessageSender(proxyMessageSupplier)));
|
||||
}
|
||||
if (useTls) {
|
||||
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
|
||||
|
||||
if (trustedServerCertificate != null) {
|
||||
sslContextBuilder.trustManager(trustedServerCertificate);
|
||||
}
|
||||
|
||||
try {
|
||||
handlers.add(sslContextBuilder.build().newHandler(ByteBufAllocator.DEFAULT));
|
||||
} catch (SSLException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
// handles the wrapping and unrwrapping the framing layer (websockets or noisedirect)
|
||||
handlers.addAll(switch (framingType) {
|
||||
case WEBSOCKET -> websocketHandlerStack(websocketUri, headers);
|
||||
case NOISE_DIRECT -> noiseDirectHandlerStack(authenticated);
|
||||
});
|
||||
|
||||
final NoiseClientHandshakeHelper helper = authenticated
|
||||
? NoiseClientHandshakeHelper.IK(serverPublicKey, ecKeyPair)
|
||||
: NoiseClientHandshakeHelper.NK(serverPublicKey);
|
||||
|
||||
handlers.add(new NoiseClientHandshakeHandler(helper));
|
||||
|
||||
// When the noise handshake completes we'll save the response from the server so client users can inspect it
|
||||
final UserEventFuture<NoiseClientHandshakeCompleteEvent> handshakeEventHandler =
|
||||
new UserEventFuture<>(NoiseClientHandshakeCompleteEvent.class);
|
||||
handlers.add(handshakeEventHandler);
|
||||
|
||||
// Whenever the framing layer sends or receives a close frame, it will emit a CloseFrameEvent and we'll save off
|
||||
// information about why the connection was closed.
|
||||
final UserEventFuture<CloseFrameEvent> closeEventHandler = new UserEventFuture<>(CloseFrameEvent.class);
|
||||
handlers.add(closeEventHandler);
|
||||
|
||||
// When the user closes the client, write a normal closure close frame
|
||||
final CompletableFuture<Void> userCloseFuture = new CompletableFuture<>();
|
||||
handlers.add(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void handlerAdded(final ChannelHandlerContext ctx) {
|
||||
userCloseFuture.thenRunAsync(() -> ctx.pipeline().writeAndFlush(switch (framingType) {
|
||||
case WEBSOCKET -> new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE);
|
||||
case NOISE_DIRECT -> new NoiseDirectFrame(
|
||||
NoiseDirectFrame.FrameType.CLOSE,
|
||||
Unpooled.wrappedBuffer(NoiseDirectProtos.CloseReason
|
||||
.newBuilder()
|
||||
.setCode(NoiseDirectProtos.CloseReason.Code.OK)
|
||||
.build()
|
||||
.toByteArray()));
|
||||
})
|
||||
.addListener(ChannelFutureListener.CLOSE),
|
||||
ctx.executor());
|
||||
}
|
||||
});
|
||||
|
||||
final NoiseTunnelClient client =
|
||||
new NoiseTunnelClient(eventLoopGroup, closeEventHandler.future, handshakeEventHandler.future, userCloseFuture, fastOpenRequest -> new EstablishRemoteConnectionHandler(
|
||||
handlers,
|
||||
remoteServerAddress,
|
||||
handshakeInit.setFastOpenRequest(ByteString.copyFrom(fastOpenRequest)).build()));
|
||||
client.start();
|
||||
return client;
|
||||
}
|
||||
}
|
||||
|
||||
private NoiseTunnelClient(NioEventLoopGroup eventLoopGroup,
|
||||
CompletableFuture<CloseFrameEvent> closeEventFuture,
|
||||
CompletableFuture<NoiseClientHandshakeCompleteEvent> handshakeEventFuture,
|
||||
CompletableFuture<Void> userCloseFuture,
|
||||
Function<byte[], EstablishRemoteConnectionHandler> handler) {
|
||||
|
||||
this.userCloseFuture = userCloseFuture;
|
||||
this.closeEventFuture = closeEventFuture;
|
||||
this.handshakeEventFuture = handshakeEventFuture;
|
||||
this.serverBootstrap = new ServerBootstrap()
|
||||
.localAddress(new LocalAddress("websocket-noise-tunnel-client"))
|
||||
.channel(LocalServerChannel.class)
|
||||
.group(eventLoopGroup)
|
||||
.childHandler(new ChannelInitializer<LocalChannel>() {
|
||||
@Override
|
||||
protected void initChannel(final LocalChannel localChannel) {
|
||||
localChannel.pipeline()
|
||||
// We just get a bytestream out of the gRPC client, but we need to pull out the first "request" from the
|
||||
// stream to do a "fast-open" request. So we buffer HTTP/2 frames until we get a whole "request" to put
|
||||
// in the handshake.
|
||||
.addLast(Http2Buffering.handler())
|
||||
// Once we have a complete request we'll get an event and after bytes will start flowing as-is again. At
|
||||
// that point we can pass everything off to the EstablishRemoteConnectionHandler which will actually
|
||||
// connect to the remote service
|
||||
.addLast(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
|
||||
if (evt instanceof FastOpenRequestBufferedEvent(ByteBuf fastOpenRequest)) {
|
||||
byte[] fastOpenRequestBytes = ByteBufUtil.getBytes(fastOpenRequest);
|
||||
fastOpenRequest.release();
|
||||
ctx.pipeline().addLast(handler.apply(fastOpenRequestBytes));
|
||||
}
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
})
|
||||
.addLast(new ClientErrorHandler());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private static class UserEventFuture<T> extends ChannelInboundHandlerAdapter {
|
||||
private final CompletableFuture<T> future = new CompletableFuture<>();
|
||||
private final Class<T> cls;
|
||||
|
||||
UserEventFuture(Class<T> cls) {
|
||||
this.cls = cls;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) {
|
||||
if (cls.isInstance(evt)) {
|
||||
future.complete((T) evt);
|
||||
}
|
||||
ctx.fireUserEventTriggered(evt);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public LocalAddress getLocalAddress() {
|
||||
return (LocalAddress) serverChannel.localAddress();
|
||||
}
|
||||
|
||||
private NoiseTunnelClient start() {
|
||||
serverChannel = serverBootstrap.bind().awaitUninterruptibly().channel();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws InterruptedException {
|
||||
userCloseFuture.complete(null);
|
||||
serverChannel.close().await();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A future that completes when a close frame is observed
|
||||
*/
|
||||
public CompletableFuture<CloseFrameEvent> closeFrameFuture() {
|
||||
return closeEventFuture;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A future that completes when the noise handshake finishes
|
||||
*/
|
||||
public CompletableFuture<NoiseClientHandshakeCompleteEvent> getHandshakeEventFuture() {
|
||||
return handshakeEventFuture;
|
||||
}
|
||||
|
||||
|
||||
private static List<ChannelHandler> noiseDirectHandlerStack(boolean authenticated) {
|
||||
return List.of(
|
||||
new LengthFieldBasedFrameDecoder(Noise.MAX_PACKET_LEN, 1, 2),
|
||||
new NoiseDirectFrameCodec(),
|
||||
new ChannelDuplexHandler() {
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx) {
|
||||
ctx.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
|
||||
ctx.fireChannelActive();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||
try {
|
||||
final NoiseDirectProtos.CloseReason closeReason =
|
||||
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
ctx.fireUserEventTriggered(
|
||||
CloseFrameEvent.fromNoiseDirectCloseFrame(closeReason, CloseFrameEvent.CloseInitiator.SERVER));
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
}
|
||||
} else {
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||
if (msg instanceof NoiseDirectFrame ndf && ndf.frameType() == NoiseDirectFrame.FrameType.CLOSE) {
|
||||
final NoiseDirectProtos.CloseReason errorPayload =
|
||||
NoiseDirectProtos.CloseReason.parseFrom(ByteBufUtil.getBytes(ndf.content()));
|
||||
ctx.fireUserEventTriggered(
|
||||
CloseFrameEvent.fromNoiseDirectCloseFrame(errorPayload, CloseFrameEvent.CloseInitiator.CLIENT));
|
||||
}
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
},
|
||||
new MessageToMessageCodec<NoiseDirectFrame, ByteBuf>() {
|
||||
boolean noiseHandshakeFinished = false;
|
||||
|
||||
@Override
|
||||
protected void encode(final ChannelHandlerContext ctx, final ByteBuf msg, final List<Object> out) {
|
||||
final NoiseDirectFrame.FrameType frameType = noiseHandshakeFinished
|
||||
? NoiseDirectFrame.FrameType.DATA
|
||||
: (authenticated ? NoiseDirectFrame.FrameType.IK_HANDSHAKE : NoiseDirectFrame.FrameType.NK_HANDSHAKE);
|
||||
noiseHandshakeFinished = true;
|
||||
out.add(new NoiseDirectFrame(frameType, msg.retain()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void decode(final ChannelHandlerContext ctx, final NoiseDirectFrame msg,
|
||||
final List<Object> out) {
|
||||
out.add(msg.content().retain());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private static List<ChannelHandler> websocketHandlerStack(final URI websocketUri, final HttpHeaders headers) {
|
||||
return List.of(
|
||||
new HttpClientCodec(),
|
||||
new HttpObjectAggregator(Noise.MAX_PACKET_LEN),
|
||||
// Inbound CloseWebSocketFrame messages wil get "eaten" by the WebSocketClientProtocolHandler, so if we
|
||||
// want to react to them on our own, we need to catch them before they hit that handler.
|
||||
new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void channelRead(final ChannelHandlerContext context, final Object message) throws Exception {
|
||||
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
||||
context.fireUserEventTriggered(
|
||||
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.SERVER));
|
||||
}
|
||||
|
||||
super.channelRead(context, message);
|
||||
}
|
||||
},
|
||||
new WebSocketClientProtocolHandler(websocketUri,
|
||||
WebSocketVersion.V13,
|
||||
null,
|
||||
false,
|
||||
headers,
|
||||
Noise.MAX_PACKET_LEN,
|
||||
10_000),
|
||||
new ChannelOutboundHandlerAdapter() {
|
||||
@Override
|
||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception {
|
||||
if (message instanceof CloseWebSocketFrame closeWebSocketFrame) {
|
||||
context.fireUserEventTriggered(
|
||||
CloseFrameEvent.fromWebsocketCloseFrame(closeWebSocketFrame, CloseFrameEvent.CloseInitiator.CLIENT));
|
||||
}
|
||||
super.write(context, message, promise);
|
||||
}
|
||||
},
|
||||
new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
if (event instanceof WebSocketClientProtocolHandler.ClientHandshakeStateEvent clientHandshakeStateEvent) {
|
||||
if (clientHandshakeStateEvent == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
|
||||
context.fireUserEventTriggered(new ReadyForNoiseHandshakeEvent());
|
||||
}
|
||||
}
|
||||
context.fireUserEventTriggered(event);
|
||||
}
|
||||
},
|
||||
new WebsocketPayloadCodec());
|
||||
}
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.client;
|
||||
|
||||
public record ReadyForNoiseHandshakeEvent() {
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.noisedirect;
|
||||
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.FramingType;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
import java.util.concurrent.Executor;
|
||||
|
||||
class DirectNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
|
||||
private NoiseDirectTunnelServer noiseDirectTunnelServer;
|
||||
|
||||
@Override
|
||||
protected void start(
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final Executor delegatedTaskExecutor,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair serverKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final String recognizedProxySecret) throws Exception {
|
||||
|
||||
noiseDirectTunnelServer = new NoiseDirectTunnelServer(0,
|
||||
eventLoopGroup,
|
||||
grpcClientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
serverKeyPair,
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress);
|
||||
noiseDirectTunnelServer.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void stop() throws InterruptedException {
|
||||
noiseDirectTunnelServer.stop();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
|
||||
return new NoiseTunnelClient
|
||||
.Builder(noiseDirectTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
|
||||
.setFramingType(FramingType.NOISE_DIRECT);
|
||||
}
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||
import java.util.List;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||
|
||||
class RejectUnsupportedMessagesHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
embeddedChannel = new EmbeddedChannel(new RejectUnsupportedMessagesHandler());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void allowWebSocketFrame(final WebSocketFrame frame) {
|
||||
embeddedChannel.writeOneInbound(frame);
|
||||
|
||||
try {
|
||||
assertEquals(frame, embeddedChannel.inboundMessages().poll());
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
assertEquals(1, frame.refCnt());
|
||||
} finally {
|
||||
frame.release();
|
||||
}
|
||||
}
|
||||
|
||||
private static List<WebSocketFrame> allowWebSocketFrame() {
|
||||
return List.of(
|
||||
new BinaryWebSocketFrame(),
|
||||
new CloseWebSocketFrame(),
|
||||
new ContinuationWebSocketFrame(),
|
||||
new PingWebSocketFrame(),
|
||||
new PongWebSocketFrame());
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectTextFrame() {
|
||||
final TextWebSocketFrame textFrame = new TextWebSocketFrame();
|
||||
embeddedChannel.writeOneInbound(textFrame);
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
assertEquals(0, textFrame.refCnt());
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectNonWebSocketFrame() {
|
||||
final ByteBuf bytes = Unpooled.buffer(0);
|
||||
embeddedChannel.writeOneInbound(bytes);
|
||||
|
||||
assertTrue(embeddedChannel.inboundMessages().isEmpty());
|
||||
assertEquals(0, bytes.refCnt());
|
||||
}
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.Status;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.KeyFactory;
|
||||
import java.security.KeyStore;
|
||||
import java.security.PrivateKey;
|
||||
import java.security.SecureRandom;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.security.spec.PKCS8EncodedKeySpec;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.TrustManagerFactory;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.signal.chat.rpc.GetRequestAttributesRequest;
|
||||
import org.signal.chat.rpc.GetRequestAttributesResponse;
|
||||
import org.signal.chat.rpc.RequestAttributesGrpc;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.CloseFrameEvent;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
class TlsWebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
|
||||
private NoiseWebSocketTunnelServer tlsNoiseWebSocketTunnelServer;
|
||||
private X509Certificate serverTlsCertificate;
|
||||
|
||||
|
||||
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
|
||||
// They were generated with:
|
||||
//
|
||||
// ```shell
|
||||
// openssl req -newkey ec:<(openssl ecparam -name secp384r1) -keyout test.key -nodes -x509 -days 36500 -out test.crt -subj "/CN=localhost"
|
||||
// ```
|
||||
private static final String SERVER_CERTIFICATE = """
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIBvDCCAUKgAwIBAgIUU16rjelaT/wClEM/SrW96VJbsiMwCgYIKoZIzj0EAwIw
|
||||
FDESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTI0MDEyNTIzMjA0OVoYDzIxMjQwMTAx
|
||||
MjMyMDQ5WjAUMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQA
|
||||
IgNiAAQOKblDCvMdPKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCV
|
||||
jttLE0TjLvgAvlJAO53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBq
|
||||
SlS8LQqjUzBRMB0GA1UdDgQWBBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAfBgNVHSME
|
||||
GDAWgBSk5UGHMmYrnaXZx+sZ1NixL5p0GTAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
|
||||
SM49BAMCA2gAMGUCMC/2Nbz2niZzz+If26n1TS68GaBlPhEqQQH4kX+De6xfeLCw
|
||||
XcCmGFLqypzWFEF+8AIxAJ2Pok9Kv2Zn+wl5KnU7d7zOcrKBZHkjXXlkMso9RWsi
|
||||
iOr9sHiO8Rn2u0xRKgU5Ig==
|
||||
-----END CERTIFICATE-----
|
||||
""";
|
||||
|
||||
// BEGIN/END PRIVATE KEY header/footer removed for easier parsing
|
||||
private static final String SERVER_PRIVATE_KEY = """
|
||||
MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDSQpS2WpySnwihcuNj
|
||||
kOVBDXGOw2UbeG/DiFSNXunyQ+8DpyGSkKk4VsluPzrepXyhZANiAAQOKblDCvMd
|
||||
PKFZ7MRePDRbSnJ4fAUoyOlOfWW1UC7NH8X2Zug4DxCtjXCVjttLE0TjLvgAvlJA
|
||||
O53+WFZV6mAm9Hds2gXMLczRZZ7g74cHyh5qFRvKJh2GeDBqSlS8LQo=
|
||||
""";
|
||||
@Override
|
||||
protected void start(
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final Executor delegatedTaskExecutor,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair serverKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final String recognizedProxySecret) throws Exception {
|
||||
final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
|
||||
serverTlsCertificate = (X509Certificate) certificateFactory.generateCertificate(
|
||||
new ByteArrayInputStream(SERVER_CERTIFICATE.getBytes(StandardCharsets.UTF_8)));
|
||||
final PrivateKey serverTlsPrivateKey;
|
||||
final KeyFactory keyFactory = KeyFactory.getInstance("EC");
|
||||
serverTlsPrivateKey =
|
||||
keyFactory.generatePrivate(new PKCS8EncodedKeySpec(Base64.getMimeDecoder().decode(SERVER_PRIVATE_KEY)));
|
||||
tlsNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
|
||||
new X509Certificate[]{serverTlsCertificate},
|
||||
serverTlsPrivateKey,
|
||||
eventLoopGroup,
|
||||
delegatedTaskExecutor,
|
||||
grpcClientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
serverKeyPair,
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress,
|
||||
recognizedProxySecret);
|
||||
tlsNoiseWebSocketTunnelServer.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void stop() throws InterruptedException {
|
||||
tlsNoiseWebSocketTunnelServer.stop();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup,
|
||||
final ECPublicKey serverPublicKey) {
|
||||
return new NoiseTunnelClient
|
||||
.Builder(tlsNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey)
|
||||
.setUseTls(serverTlsCertificate);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getRequestAttributes() throws InterruptedException {
|
||||
final String remoteAddress = "4.5.6.7";
|
||||
final String acceptLanguage = "en";
|
||||
final String userAgent = "Signal-Desktop/1.2.3 Linux";
|
||||
|
||||
final HttpHeaders headers = new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add("X-Forwarded-For", remoteAddress);
|
||||
|
||||
try (final NoiseTunnelClient client = anonymous()
|
||||
.setHeaders(headers)
|
||||
.setUserAgent(userAgent)
|
||||
.setAcceptLanguage(acceptLanguage)
|
||||
.build()) {
|
||||
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
|
||||
|
||||
assertEquals(remoteAddress, response.getRemoteAddress());
|
||||
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
|
||||
assertEquals(userAgent, response.getUserAgent());
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void connectAuthenticatedToAnonymousService() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
try (final NoiseTunnelClient client = authenticated()
|
||||
.setWebsocketUri(NoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI)
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void connectAnonymousToAuthenticatedService() throws InterruptedException, ExecutionException, TimeoutException {
|
||||
try (final NoiseTunnelClient client = anonymous()
|
||||
.setWebsocketUri(NoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI)
|
||||
.build()) {
|
||||
final ManagedChannel channel = buildManagedChannel(client.getLocalAddress());
|
||||
|
||||
try {
|
||||
//noinspection ResultOfMethodCallIgnored
|
||||
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
|
||||
() -> RequestAttributesGrpc.newBlockingStub(channel)
|
||||
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
}
|
||||
assertClosedWith(client, CloseFrameEvent.CloseReason.NOISE_HANDSHAKE_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectIllegalRequests() throws Exception {
|
||||
|
||||
final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
|
||||
keyStore.load(null, null);
|
||||
keyStore.setCertificateEntry("tunnel", serverTlsCertificate);
|
||||
|
||||
final TrustManagerFactory trustManagerFactory =
|
||||
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
|
||||
|
||||
trustManagerFactory.init(keyStore);
|
||||
|
||||
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
||||
sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom());
|
||||
|
||||
final URI authenticatedUri =
|
||||
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/authenticated",
|
||||
null, null);
|
||||
|
||||
final URI incorrectUri =
|
||||
new URI("https", null, "localhost", tlsNoiseWebSocketTunnelServer.getLocalAddress().getPort(), "/incorrect",
|
||||
null, null);
|
||||
|
||||
try (final HttpClient httpClient = HttpClient.newBuilder().sslContext(sslContext).build()) {
|
||||
assertEquals(405, httpClient.send(HttpRequest.newBuilder()
|
||||
.uri(authenticatedUri)
|
||||
.PUT(HttpRequest.BodyPublishers.ofString("test"))
|
||||
.build(),
|
||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||
"Non-GET requests should not be allowed");
|
||||
|
||||
assertEquals(426, httpClient.send(HttpRequest.newBuilder()
|
||||
.GET()
|
||||
.uri(authenticatedUri)
|
||||
.build(),
|
||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||
"GET requests without upgrade headers should not be allowed");
|
||||
|
||||
assertEquals(404, httpClient.send(HttpRequest.newBuilder()
|
||||
.GET()
|
||||
.uri(incorrectUri)
|
||||
.build(),
|
||||
HttpResponse.BodyHandlers.ofString()).statusCode(),
|
||||
"GET requests to unrecognized URIs should not be allowed");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import java.util.concurrent.Executor;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractNoiseTunnelServerIntegrationTest;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.client.NoiseTunnelClient;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
|
||||
|
||||
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractNoiseTunnelServerIntegrationTest {
|
||||
private NoiseWebSocketTunnelServer plaintextNoiseWebSocketTunnelServer;
|
||||
|
||||
@Override
|
||||
protected void start(
|
||||
final NioEventLoopGroup eventLoopGroup,
|
||||
final Executor delegatedTaskExecutor,
|
||||
final GrpcClientConnectionManager grpcClientConnectionManager,
|
||||
final ClientPublicKeysManager clientPublicKeysManager,
|
||||
final ECKeyPair serverKeyPair,
|
||||
final LocalAddress authenticatedGrpcServerAddress,
|
||||
final LocalAddress anonymousGrpcServerAddress,
|
||||
final String recognizedProxySecret) throws Exception {
|
||||
plaintextNoiseWebSocketTunnelServer = new NoiseWebSocketTunnelServer(0,
|
||||
null,
|
||||
null,
|
||||
eventLoopGroup,
|
||||
delegatedTaskExecutor,
|
||||
grpcClientConnectionManager,
|
||||
clientPublicKeysManager,
|
||||
serverKeyPair,
|
||||
authenticatedGrpcServerAddress,
|
||||
anonymousGrpcServerAddress,
|
||||
recognizedProxySecret);
|
||||
plaintextNoiseWebSocketTunnelServer.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void stop() throws InterruptedException {
|
||||
plaintextNoiseWebSocketTunnelServer.stop();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NoiseTunnelClient.Builder clientBuilder(final NioEventLoopGroup eventLoopGroup, final ECPublicKey serverPublicKey) {
|
||||
return new NoiseTunnelClient
|
||||
.Builder(plaintextNoiseWebSocketTunnelServer.getLocalAddress(), eventLoopGroup, serverPublicKey);
|
||||
}
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpRequest;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.FullHttpResponse;
|
||||
import io.netty.handler.codec.http.HttpHeaderNames;
|
||||
import io.netty.handler.codec.http.HttpHeaderValues;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpMethod;
|
||||
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||
import io.netty.handler.codec.http.HttpVersion;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||
|
||||
class WebSocketOpeningHandshakeHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private EmbeddedChannel embeddedChannel;
|
||||
|
||||
private static final String AUTHENTICATED_PATH = "/authenticated";
|
||||
private static final String ANONYMOUS_PATH = "/anonymous";
|
||||
private static final String HEALTH_CHECK_PATH = "/health-check";
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
embeddedChannel =
|
||||
new EmbeddedChannel(new WebSocketOpeningHandshakeHandler(AUTHENTICATED_PATH, ANONYMOUS_PATH, HEALTH_CHECK_PATH));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
|
||||
void handleValidRequest(final String path) {
|
||||
final FullHttpRequest request = buildRequest(HttpMethod.GET, path,
|
||||
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
|
||||
|
||||
try {
|
||||
embeddedChannel.writeOneInbound(request);
|
||||
|
||||
assertEquals(1, request.refCnt());
|
||||
assertEquals(1, embeddedChannel.inboundMessages().size());
|
||||
assertEquals(request, embeddedChannel.inboundMessages().poll());
|
||||
} finally {
|
||||
request.release();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleHealthCheckRequest() {
|
||||
final FullHttpRequest request = buildRequest(HttpMethod.GET, HEALTH_CHECK_PATH, new DefaultHttpHeaders());
|
||||
|
||||
embeddedChannel.writeOneInbound(request);
|
||||
|
||||
assertEquals(0, request.refCnt());
|
||||
assertHttpResponse(HttpResponseStatus.NO_CONTENT);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
|
||||
void handleUpgradeRequired(final String path) {
|
||||
final FullHttpRequest request = buildRequest(HttpMethod.GET, path, new DefaultHttpHeaders());
|
||||
|
||||
embeddedChannel.writeOneInbound(request);
|
||||
|
||||
assertEquals(0, request.refCnt());
|
||||
assertHttpResponse(HttpResponseStatus.UPGRADE_REQUIRED);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleBadPath() {
|
||||
final FullHttpRequest request = buildRequest(HttpMethod.GET, "/incorrect",
|
||||
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
|
||||
|
||||
embeddedChannel.writeOneInbound(request);
|
||||
|
||||
assertEquals(0, request.refCnt());
|
||||
assertHttpResponse(HttpResponseStatus.NOT_FOUND);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = { AUTHENTICATED_PATH, ANONYMOUS_PATH })
|
||||
void handleMethodNotAllowed(final String path) {
|
||||
final FullHttpRequest request = buildRequest(HttpMethod.DELETE, path,
|
||||
new DefaultHttpHeaders().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET));
|
||||
|
||||
embeddedChannel.writeOneInbound(request);
|
||||
|
||||
assertEquals(0, request.refCnt());
|
||||
assertHttpResponse(HttpResponseStatus.METHOD_NOT_ALLOWED);
|
||||
}
|
||||
|
||||
private void assertHttpResponse(final HttpResponseStatus expectedStatus) {
|
||||
assertEquals(1, embeddedChannel.outboundMessages().size());
|
||||
|
||||
final FullHttpResponse response = assertInstanceOf(FullHttpResponse.class, embeddedChannel.outboundMessages().poll());
|
||||
|
||||
assertEquals(expectedStatus, response.status());
|
||||
}
|
||||
|
||||
private FullHttpRequest buildRequest(final HttpMethod method, final String path, final HttpHeaders headers) {
|
||||
return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1,
|
||||
method,
|
||||
path,
|
||||
Unpooled.buffer(0),
|
||||
headers,
|
||||
new DefaultHttpHeaders());
|
||||
}
|
||||
}
|
||||
@@ -1,233 +0,0 @@
|
||||
package org.whispersystems.textsecuregcm.grpc.net.websocket;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
|
||||
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
||||
|
||||
import com.google.common.net.InetAddresses;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.channel.local.LocalAddress;
|
||||
import io.netty.handler.codec.http.DefaultHttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
|
||||
import java.net.InetAddress;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.AbstractLeakDetectionTest;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.HandshakePattern;
|
||||
import org.whispersystems.textsecuregcm.grpc.net.NoiseHandshakeInit;
|
||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||
|
||||
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
|
||||
|
||||
private UserEventRecordingHandler userEventRecordingHandler;
|
||||
private MutableRemoteAddressEmbeddedChannel embeddedChannel;
|
||||
|
||||
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.secure().nextAlphanumeric(16);
|
||||
|
||||
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
|
||||
|
||||
private final List<Object> receivedEvents = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
|
||||
receivedEvents.add(event);
|
||||
}
|
||||
|
||||
public List<Object> getReceivedEvents() {
|
||||
return receivedEvents;
|
||||
}
|
||||
}
|
||||
|
||||
private static class MutableRemoteAddressEmbeddedChannel extends EmbeddedChannel {
|
||||
|
||||
private SocketAddress remoteAddress;
|
||||
|
||||
public MutableRemoteAddressEmbeddedChannel(final ChannelHandler... handlers) {
|
||||
super(handlers);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SocketAddress remoteAddress0() {
|
||||
return isActive() ? remoteAddress : null;
|
||||
}
|
||||
|
||||
public void setRemoteAddress(final SocketAddress remoteAddress) {
|
||||
this.remoteAddress = remoteAddress;
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
userEventRecordingHandler = new UserEventRecordingHandler();
|
||||
|
||||
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
|
||||
new WebsocketHandshakeCompleteHandler(RECOGNIZED_PROXY_SECRET),
|
||||
userEventRecordingHandler);
|
||||
|
||||
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void handleWebSocketHandshakeComplete(final String uri, final HandshakePattern pattern) {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
|
||||
|
||||
final byte[] payload = TestRandomUtil.nextBytes(100);
|
||||
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
|
||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
|
||||
assertNotNull(init);
|
||||
assertEquals(init.getHandshakePattern(), pattern);
|
||||
}
|
||||
|
||||
private static List<Arguments> handleWebSocketHandshakeComplete() {
|
||||
return List.of(
|
||||
Arguments.of(NoiseWebSocketTunnelServer.AUTHENTICATED_SERVICE_PATH, HandshakePattern.IK),
|
||||
Arguments.of(NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, HandshakePattern.NK));
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleWebSocketHandshakeCompleteUnexpectedPath() {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleUnrecognizedEvent() {
|
||||
final Object unrecognizedEvent = new Object();
|
||||
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
|
||||
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getRemoteAddress(final HttpHeaders headers, final SocketAddress remoteAddress, @Nullable InetAddress expectedRemoteAddress) {
|
||||
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
|
||||
new WebSocketServerProtocolHandler.HandshakeComplete(
|
||||
NoiseWebSocketTunnelServer.ANONYMOUS_SERVICE_PATH, headers, null);
|
||||
|
||||
embeddedChannel.setRemoteAddress(remoteAddress);
|
||||
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
|
||||
|
||||
final byte[] payload = TestRandomUtil.nextBytes(100);
|
||||
embeddedChannel.pipeline().fireChannelRead(Unpooled.wrappedBuffer(payload));
|
||||
final NoiseHandshakeInit init = (NoiseHandshakeInit) embeddedChannel.inboundMessages().poll();
|
||||
assertEquals(
|
||||
expectedRemoteAddress,
|
||||
Optional.ofNullable(init)
|
||||
.map(NoiseHandshakeInit::getRemoteAddress)
|
||||
.orElse(null));
|
||||
if (expectedRemoteAddress == null) {
|
||||
assertThrows(IllegalStateException.class, embeddedChannel::checkException);
|
||||
} else {
|
||||
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
|
||||
}
|
||||
}
|
||||
|
||||
private static List<Arguments> getRemoteAddress() {
|
||||
final InetSocketAddress remoteAddress = new InetSocketAddress("5.6.7.8", 0);
|
||||
final InetAddress clientAddress = InetAddresses.forString("1.2.3.4");
|
||||
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
|
||||
|
||||
return List.of(
|
||||
argumentSet("Recognized proxy, single forwarded-for address",
|
||||
new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
clientAddress),
|
||||
|
||||
argumentSet("Recognized proxy, multiple forwarded-for addresses",
|
||||
new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
proxyAddress),
|
||||
|
||||
argumentSet("No recognized proxy header, single forwarded-for address",
|
||||
new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
argumentSet("No recognized proxy header, no forwarded-for address",
|
||||
new DefaultHttpHeaders(),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
argumentSet("Incorrect proxy header, single forwarded-for address",
|
||||
new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
argumentSet("Recognized proxy, no forwarded-for address",
|
||||
new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
|
||||
remoteAddress,
|
||||
remoteAddress.getAddress()),
|
||||
|
||||
argumentSet("Recognized proxy, bogus forwarded-for address",
|
||||
new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
|
||||
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
|
||||
remoteAddress,
|
||||
null),
|
||||
|
||||
argumentSet("No forwarded-for address, non-InetSocketAddress remote address",
|
||||
new DefaultHttpHeaders()
|
||||
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
|
||||
new LocalAddress("local-address"),
|
||||
null)
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
||||
@ParameterizedTest
|
||||
@MethodSource("argumentsForGetMostRecentProxy")
|
||||
void getMostRecentProxy(final String forwardedFor, final Optional<String> expectedMostRecentProxy) {
|
||||
assertEquals(expectedMostRecentProxy, WebsocketHandshakeCompleteHandler.getMostRecentProxy(forwardedFor));
|
||||
}
|
||||
|
||||
private static Stream<Arguments> argumentsForGetMostRecentProxy() {
|
||||
return Stream.of(
|
||||
arguments(null, Optional.empty()),
|
||||
arguments("", Optional.empty()),
|
||||
arguments(" ", Optional.empty()),
|
||||
arguments("203.0.113.195,", Optional.empty()),
|
||||
arguments("203.0.113.195, ", Optional.empty()),
|
||||
arguments("203.0.113.195", Optional.of("203.0.113.195")),
|
||||
arguments("203.0.113.195, 70.41.3.18, 150.172.238.178", Optional.of("150.172.238.178"))
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -172,8 +172,3 @@ turn.cloudflare.apiToken: ABCDEFGHIJKLM
|
||||
linkDevice.secret: AAAAAAAAAAA=
|
||||
|
||||
tlsKeyStore.password: unset
|
||||
|
||||
# The below private key was generated exclusively for testing purposes. Do not use it in any other context.
|
||||
# Corresponding public key: cYUAFtkWK/4x3AfW/yw7qgIo/mQUaRSWaPolGQkiL14=
|
||||
noiseTunnel.noiseStaticPrivateKey: qK5FD9WmuhoLPsS/Z4swcZkwDn9OpeM5ZmcEVMpEQ24=
|
||||
noiseTunnel.recognizedProxySecret: ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789AAAAAAA
|
||||
|
||||
@@ -525,12 +525,6 @@ turn:
|
||||
linkDevice:
|
||||
secret: secret://linkDevice.secret
|
||||
|
||||
noiseTunnel:
|
||||
webSocketPort: 8444
|
||||
directPort: 8445
|
||||
noiseStaticPrivateKey: secret://noiseTunnel.noiseStaticPrivateKey
|
||||
recognizedProxySecret: secret://noiseTunnel.recognizedProxySecret
|
||||
|
||||
externalRequestFilter:
|
||||
grpcMethods:
|
||||
- com.example.grpc.ExampleService/exampleMethod
|
||||
@@ -541,3 +535,6 @@ externalRequestFilter:
|
||||
|
||||
idlePrimaryDeviceReminder:
|
||||
minIdleDuration: P30D
|
||||
|
||||
grpc:
|
||||
port: 50051
|
||||
|
||||
Reference in New Issue
Block a user