Introduce a Noise-over-WebSocket client connection manager

This commit is contained in:
Jon Chambers
2024-03-22 15:20:55 -04:00
committed by GitHub
parent 075a08884b
commit aec6ac019f
53 changed files with 1818 additions and 933 deletions

View File

@@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration;
import org.whispersystems.textsecuregcm.configuration.NoiseWebSocketTunnelConfiguration;
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration;
@@ -333,6 +334,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private MonitoredS3ObjectConfiguration callingTurnManualTable;
@Valid
@NotNull
@JsonProperty
private NoiseWebSocketTunnelConfiguration noiseTunnel;
public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() {
return tlsKeyStore;
}
@@ -555,4 +561,8 @@ public class WhisperServerConfiguration extends Configuration {
public MonitoredS3ObjectConfiguration getCallingTurnManualTable() {
return callingTurnManualTable;
}
public NoiseWebSocketTunnelConfiguration getNoiseWebSocketTunnelConfiguration() {
return noiseTunnel;
}
}

View File

@@ -71,7 +71,8 @@ import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.auth.grpc.BasicCredentialAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
import org.whispersystems.textsecuregcm.backup.BackupManager;
import org.whispersystems.textsecuregcm.backup.BackupsDb;
@@ -127,7 +128,6 @@ import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.geo.MaxMindDatabaseManager;
import org.whispersystems.textsecuregcm.grpc.AcceptLanguageInterceptor;
import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.AccountsGrpcService;
import org.whispersystems.textsecuregcm.grpc.ErrorMappingInterceptor;
@@ -138,7 +138,8 @@ import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
import org.whispersystems.textsecuregcm.grpc.PaymentsGrpcService;
import org.whispersystems.textsecuregcm.grpc.ProfileAnonymousGrpcService;
import org.whispersystems.textsecuregcm.grpc.ProfileGrpcService;
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.grpc.net.ManagedDefaultEventLoopGroup;
import org.whispersystems.textsecuregcm.grpc.net.ManagedLocalGrpcServer;
import org.whispersystems.textsecuregcm.jetty.JettyHttpConfigurationCustomizer;
@@ -752,8 +753,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
geoIpCityDatabaseManager
);
final BasicCredentialAuthenticationInterceptor basicCredentialAuthenticationInterceptor =
new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager));
final ClientConnectionManager clientConnectionManager = new ClientConnectionManager();
final ManagedDefaultEventLoopGroup localEventLoopGroup = new ManagedDefaultEventLoopGroup();
@@ -762,8 +762,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new MetricCollectingServerInterceptor(Metrics.globalRegistry);
final ErrorMappingInterceptor errorMappingInterceptor = new ErrorMappingInterceptor();
final AcceptLanguageInterceptor acceptLanguageInterceptor = new AcceptLanguageInterceptor();
final UserAgentInterceptor userAgentInterceptor = new UserAgentInterceptor();
final RequestAttributesInterceptor requestAttributesInterceptor =
new RequestAttributesInterceptor(clientConnectionManager);
final LocalAddress anonymousGrpcServerAddress = new LocalAddress("grpc-anonymous");
final LocalAddress authenticatedGrpcServerAddress = new LocalAddress("grpc-authenticated");
@@ -778,9 +778,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
// TODO: specialize metrics with user-agent platform
.intercept(metricCollectingServerInterceptor)
.intercept(errorMappingInterceptor)
.intercept(acceptLanguageInterceptor)
.intercept(remoteDeprecationFilter)
.intercept(userAgentInterceptor)
.intercept(requestAttributesInterceptor)
.intercept(new ProhibitAuthenticationInterceptor(clientConnectionManager))
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
.addService(new PaymentsGrpcService(currencyManager))
@@ -799,10 +799,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
// TODO: specialize metrics with user-agent platform
.intercept(metricCollectingServerInterceptor)
.intercept(errorMappingInterceptor)
.intercept(acceptLanguageInterceptor)
.intercept(remoteDeprecationFilter)
.intercept(userAgentInterceptor)
.intercept(new BasicCredentialAuthenticationInterceptor(new AccountAuthenticator(accountsManager)))
.intercept(requestAttributesInterceptor)
.intercept(new RequireAuthenticationInterceptor(clientConnectionManager))
.addService(new AccountsGrpcService(accountsManager, rateLimiters, usernameHashZkProofVerifier, registrationRecoveryPasswordsManager))
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
.addService(new KeysGrpcService(accountsManager, keysManager, rateLimiters))

View File

@@ -0,0 +1,34 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import java.util.Optional;
abstract class AbstractAuthenticationInterceptor implements ServerInterceptor {
private final ClientConnectionManager clientConnectionManager;
private static final Metadata EMPTY_TRAILERS = new Metadata();
AbstractAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
this.clientConnectionManager = clientConnectionManager;
}
protected Optional<AuthenticatedDevice> getAuthenticatedDevice(final ServerCall<?, ?> call) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
return clientConnectionManager.getAuthenticatedDevice(localAddress);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
protected <ReqT, RespT> ServerCall.Listener<ReqT> closeAsUnauthenticated(final ServerCall<ReqT, RespT> call) {
call.close(Status.UNAUTHENTICATED, EMPTY_TRAILERS);
return new ServerCall.Listener<>() {};
}
}

View File

@@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Context;
import io.grpc.Status;
import java.util.UUID;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -16,8 +15,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
*/
public class AuthenticationUtil {
static final Context.Key<UUID> CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci");
static final Context.Key<Byte> CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id");
static final Context.Key<AuthenticatedDevice> CONTEXT_AUTHENTICATED_DEVICE = Context.key("authenticated-device");
/**
* Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if
@@ -29,11 +27,10 @@ public class AuthenticationUtil {
* could be retrieved from the current gRPC context
*/
public static AuthenticatedDevice requireAuthenticatedDevice() {
@Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get();
@Nullable final Byte deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get();
@Nullable final AuthenticatedDevice authenticatedDevice = CONTEXT_AUTHENTICATED_DEVICE.get();
if (accountIdentifier != null && deviceId != null) {
return new AuthenticatedDevice(accountIdentifier, deviceId);
if (authenticatedDevice != null) {
return authenticatedDevice;
}
throw Status.UNAUTHENTICATED.asRuntimeException();

View File

@@ -1,85 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth.grpc;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.basic.BasicCredentials;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
/**
* A basic credential authentication interceptor enforces the presence of a valid username and password on every call.
* Callers supply credentials by providing a username (UUID and optional device ID) and password pair in the
* {@code x-signal-basic-auth-credentials} call header.
* <p/>
* Downstream services can retrieve the identity of the authenticated caller using methods in
* {@link AuthenticationUtil}.
* <p/>
* Note that this authentication, while fully functional, is intended only for development and testing purposes and is
* intended to be replaced with a more robust and efficient strategy before widespread client adoption.
*
* @see AuthenticationUtil
* @see AccountAuthenticator
*/
public class BasicCredentialAuthenticationInterceptor implements ServerInterceptor {
private final AccountAuthenticator accountAuthenticator;
@VisibleForTesting
static final Metadata.Key<String> BASIC_CREDENTIALS =
Metadata.Key.of("x-signal-auth", Metadata.ASCII_STRING_MARSHALLER);
private static final Metadata EMPTY_TRAILERS = new Metadata();
public BasicCredentialAuthenticationInterceptor(final AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
final String authHeader = headers.get(BASIC_CREDENTIALS);
if (StringUtils.isNotBlank(authHeader)) {
final Optional<BasicCredentials> maybeCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeader);
if (maybeCredentials.isEmpty()) {
call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS);
} else {
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
accountAuthenticator.authenticate(maybeCredentials.get());
if (maybeAuthenticatedAccount.isPresent()) {
final AuthenticatedAccount authenticatedAccount = maybeAuthenticatedAccount.get();
final Context context = Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedAccount.getAccount().getUuid())
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedAccount.getAuthenticatedDevice().getId());
return Contexts.interceptCall(context, call, headers, next);
} else {
call.close(Status.UNAUTHENTICATED.withDescription("Credentials not accepted"), EMPTY_TRAILERS);
}
}
} else {
call.close(Status.UNAUTHENTICATED.withDescription("No credentials provided"), EMPTY_TRAILERS);
}
return new ServerCall.Listener<>() {};
}
}

View File

@@ -0,0 +1,28 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
/**
* A "prohibit authentication" interceptor ensures that requests to endpoints that should be invoked anonymously do not
* originate from a channel that is associated with an authenticated device. Calls with an associated authenticated
* device are closed with an {@code UNAUTHENTICATED} status.
*/
public class ProhibitAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
public ProhibitAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
super(clientConnectionManager);
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call)
.map(ignored -> closeAsUnauthenticated(call))
.orElseGet(() -> next.startCall(call, headers));
}
}

View File

@@ -0,0 +1,32 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
/**
* A "require authentication" interceptor requires that requests be issued from a connection that is associated with an
* authenticated device. Calls without an associated authenticated device are closed with an {@code UNAUTHENTICATED}
* status.
*/
public class RequireAuthenticationInterceptor extends AbstractAuthenticationInterceptor {
public RequireAuthenticationInterceptor(final ClientConnectionManager clientConnectionManager) {
super(clientConnectionManager);
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return getAuthenticatedDevice(call)
.map(authenticatedDevice -> Contexts.interceptCall(Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next))
.orElseGet(() -> closeAsUnauthenticated(call));
}
}

View File

@@ -0,0 +1,8 @@
package org.whispersystems.textsecuregcm.configuration;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Positive;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretString;
public record NoiseWebSocketTunnelConfiguration(@Positive int port, @NotNull SecretString recognizedProxySecret) {
}

View File

@@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.filters;
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.Optional;
import javax.annotation.Nullable;
@@ -72,8 +71,7 @@ public class RemoteAddressFilter implements Filter {
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
* MDN</a>
*/
@VisibleForTesting
static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
public static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
return Optional.ofNullable(forwardedFor)
.map(ff -> {
final int idx = forwardedFor.lastIndexOf(',') + 1;

View File

@@ -17,6 +17,7 @@ import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
@@ -26,6 +27,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesUtil;
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@@ -79,7 +81,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (shouldBlock(UserAgentUtil.userAgentFromGrpcContext())) {
if (shouldBlock(RequestAttributesUtil.getUserAgent().orElse(null))) {
call.close(StatusConstants.UPGRADE_NEEDED_STATUS, new Metadata());
return new ServerCall.Listener<>() {};
} else {
@@ -87,7 +89,7 @@ public class RemoteDeprecationFilter implements Filter, ServerInterceptor {
}
}
private boolean shouldBlock(final UserAgent userAgent) {
private boolean shouldBlock(@Nullable final UserAgent userAgent) {
final DynamicRemoteDeprecationConfiguration configuration = dynamicConfigurationManager
.getConfiguration().getRemoteDeprecationConfiguration();
final Map<ClientPlatform, Semver> minimumVersionsByPlatform = configuration.getMinimumVersions();

View File

@@ -1,68 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class AcceptLanguageInterceptor implements ServerInterceptor {
private static final Logger logger = LoggerFactory.getLogger(AcceptLanguageInterceptor.class);
private static final String INVALID_ACCEPT_LANGUAGE_COUNTER_NAME = name(AcceptLanguageInterceptor.class, "invalidAcceptLanguage");
@VisibleForTesting
public static final Metadata.Key<String> ACCEPTABLE_LANGUAGES_GRPC_HEADER =
Metadata.Key.of("accept-language", Metadata.ASCII_STRING_MARSHALLER);
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
final List<Locale> locales = parseLocales(headers.get(ACCEPTABLE_LANGUAGES_GRPC_HEADER));
return Contexts.interceptCall(
Context.current().withValue(AcceptLanguageUtil.ACCEPTABLE_LANGUAGES_CONTEXT_KEY, locales),
call,
headers,
next);
}
static List<Locale> parseLocales(@Nullable final String acceptableLanguagesHeader) {
if (acceptableLanguagesHeader == null) {
return Collections.emptyList();
}
try {
final List<Locale.LanguageRange> languageRanges = Locale.LanguageRange.parse(acceptableLanguagesHeader);
return Locale.filter(languageRanges, Arrays.asList(Locale.getAvailableLocales()));
} catch (final IllegalArgumentException e) {
final UserAgent userAgent = UserAgentUtil.userAgentFromGrpcContext();
Metrics.counter(INVALID_ACCEPT_LANGUAGE_COUNTER_NAME, "platform", userAgent.getPlatform().name().toLowerCase()).increment();
logger.debug("Could not get acceptable languages; Accept-Language: {}; User-Agent: {}",
acceptableLanguagesHeader,
userAgent,
e);
return Collections.emptyList();
}
}
}

View File

@@ -1,17 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.util.List;
import java.util.Locale;
public class AcceptLanguageUtil {
static final Context.Key<List<Locale>> ACCEPTABLE_LANGUAGES_CONTEXT_KEY = Context.key("accept-language");
public static List<Locale> localeFromGrpcContext() {
return ACCEPTABLE_LANGUAGES_CONTEXT_KEY.get();
}
}

View File

@@ -111,7 +111,7 @@ public class ProfileGrpcHelper {
case ACI -> {
responseBuilder.setUnrestrictedUnidentifiedAccess(targetAccount.isUnrestrictedUnidentifiedAccess())
.addAllBadges(buildBadges(profileBadgeConverter.convert(
AcceptLanguageUtil.localeFromGrpcContext(),
RequestAttributesUtil.getAvailableAcceptedLocales(),
targetAccount.getBadges(),
ProfileHelper.isSelfProfileRequest(requesterUuid, (AciServiceIdentifier) targetIdentifier))));

View File

@@ -5,28 +5,12 @@
package org.whispersystems.textsecuregcm.grpc;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import reactor.core.publisher.Mono;
import java.net.SocketAddress;
import java.time.Duration;
class RateLimitUtil {
private static final RateLimitExceededException UNKNOWN_REMOTE_ADDRESS_EXCEPTION =
new RateLimitExceededException(Duration.ofHours(1), true);
static Mono<Void> rateLimitByRemoteAddress(final RateLimiter rateLimiter) {
return rateLimitByRemoteAddress(rateLimiter, true);
}
static Mono<Void> rateLimitByRemoteAddress(final RateLimiter rateLimiter, final boolean failOnUnknownRemoteAddress) {
final SocketAddress remoteAddress = RemoteAddressUtil.getRemoteAddress();
if (remoteAddress != null) {
return rateLimiter.validateReactive(remoteAddress.toString());
} else {
return failOnUnknownRemoteAddress ? Mono.error(UNKNOWN_REMOTE_ADDRESS_EXCEPTION) : Mono.empty();
}
return rateLimiter.validateReactive(RequestAttributesUtil.getRemoteAddress().getHostAddress());
}
}

View File

@@ -1,33 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import java.net.SocketAddress;
public class RemoteAddressInterceptor implements ServerInterceptor {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
// Note: the specific implementation for getting a remote client address may change depending on the client
// connection strategy. The important thing is that the remote address wind up in the context for the current
// call so it can be retrieved by `RemoteAddressUtil`.
final SocketAddress remoteAddress = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
return Contexts.interceptCall(
Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress),
serverCall, headers, next);
}
}

View File

@@ -1,23 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.net.SocketAddress;
public class RemoteAddressUtil {
static final Context.Key<SocketAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
/**
* Returns the socket address of the remote client in the current gRPC request context.
*
* @return the socket address of the remote client
*/
public static SocketAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get();
}
}

View File

@@ -0,0 +1,75 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.netty.channel.local.LocalAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
public class RequestAttributesInterceptor implements ServerInterceptor {
private final ClientConnectionManager clientConnectionManager;
private static final Logger log = LoggerFactory.getLogger(RequestAttributesInterceptor.class);
public RequestAttributesInterceptor(final ClientConnectionManager clientConnectionManager) {
this.clientConnectionManager = clientConnectionManager;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR) instanceof LocalAddress localAddress) {
Context context = Context.current();
{
final Optional<InetAddress> maybeRemoteAddress = clientConnectionManager.getRemoteAddress(localAddress);
if (maybeRemoteAddress.isEmpty()) {
// We should never have a call from a party whose remote address we can't identify
log.warn("No remote address available");
call.close(Status.INTERNAL, new Metadata());
return new ServerCall.Listener<>() {};
}
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, maybeRemoteAddress.get());
}
{
final Optional<List<Locale.LanguageRange>> maybeAcceptLanguage =
clientConnectionManager.getAcceptableLanguages(localAddress);
if (maybeAcceptLanguage.isPresent()) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, maybeAcceptLanguage.get());
}
}
{
final Optional<UserAgent> maybeUserAgent = clientConnectionManager.getUserAgent(localAddress);
if (maybeUserAgent.isPresent()) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, maybeUserAgent.get());
}
}
return Contexts.interceptCall(context, call, headers, next);
} else {
throw new AssertionError("Unexpected channel type: " + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
}
}
}

View File

@@ -0,0 +1,59 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class RequestAttributesUtil {
static final Context.Key<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_CONTEXT_KEY = Context.key("accept-language");
static final Context.Key<InetAddress> REMOTE_ADDRESS_CONTEXT_KEY = Context.key("remote-address");
static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("user-agent");
private static final List<Locale> AVAILABLE_LOCALES = Arrays.asList(Locale.getAvailableLocales());
/**
* Returns the acceptable languages listed by the remote client in the current gRPC request context.
*
* @return the acceptable languages listed by the remote client; may be empty if unparseable or not specified
*/
public static Optional<List<Locale.LanguageRange>> getAcceptableLanguages() {
return Optional.ofNullable(ACCEPT_LANGUAGE_CONTEXT_KEY.get());
}
/**
* Returns a list of distinct locales supported by the JVM and accepted by the remote client in the current gRPC
* context. May be empty if the client did not supply a list of acceptable languages, if the list of acceptable
* languages could not be parsed, or if none of the acceptable languages are available in the current JVM.
*
* @return a list of distinct locales acceptable to the remote client and available in this JVM
*/
public static List<Locale> getAvailableAcceptedLocales() {
return getAcceptableLanguages()
.map(languageRanges -> Locale.filter(languageRanges, AVAILABLE_LOCALES))
.orElseGet(Collections::emptyList);
}
/**
* Returns the remote address of the remote client in the current gRPC request context.
*
* @return the remote address of the remote client
*/
public static InetAddress getRemoteAddress() {
return REMOTE_ADDRESS_CONTEXT_KEY.get();
}
/**
* Returns the parsed user-agent of the remote client in the current gRPC request context.
*
* @return the parsed user-agent of the remote client; may be null if unparseable or not specified
*/
public static Optional<UserAgent> getUserAgent() {
return Optional.ofNullable(USER_AGENT_CONTEXT_KEY.get());
}
}

View File

@@ -1,43 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
public class UserAgentInterceptor implements ServerInterceptor {
@VisibleForTesting
public static final Metadata.Key<String> USER_AGENT_GRPC_HEADER =
Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER);
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> call,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
UserAgent userAgent;
try {
userAgent = UserAgentUtil.parseUserAgentString(headers.get(USER_AGENT_GRPC_HEADER));
} catch (final UnrecognizedUserAgentException e) {
userAgent = null;
}
final Context context = Context.current().withValue(UserAgentUtil.USER_AGENT_CONTEXT_KEY, userAgent);
return Contexts.interceptCall(context, call, headers, next);
}
}

View File

@@ -0,0 +1,218 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.util.AttributeKey;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
/**
* A client connection manager associates a local connection to a local gRPC server with a remote connection through a
* Noise-over-WebSocket tunnel. It provides access to metadata associated with the remote connection, including the
* authenticated identity of the device that opened the connection (for non-anonymous connections). It can also close
* connections associated with a given device if that device's credentials have changed and clients must reauthenticate.
*/
public class ClientConnectionManager {
private final Map<LocalAddress, Channel> remoteChannelsByLocalAddress = new ConcurrentHashMap<>();
private final Map<AuthenticatedDevice, List<Channel>> remoteChannelsByAuthenticatedDevice = new ConcurrentHashMap<>();
@VisibleForTesting
static final AttributeKey<AuthenticatedDevice> AUTHENTICATED_DEVICE_ATTRIBUTE_KEY =
AttributeKey.valueOf(ClientConnectionManager.class, "authenticatedDevice");
@VisibleForTesting
static final AttributeKey<InetAddress> REMOTE_ADDRESS_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "remoteAddress");
@VisibleForTesting
static final AttributeKey<String> RAW_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "rawUserAgent");
@VisibleForTesting
static final AttributeKey<UserAgent> PARSED_USER_AGENT_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "userAgent");
@VisibleForTesting
static final AttributeKey<List<Locale.LanguageRange>> ACCEPT_LANGUAGE_ATTRIBUTE_KEY =
AttributeKey.valueOf(WebsocketHandshakeCompleteHandler.class, "acceptLanguage");
private static final Logger log = LoggerFactory.getLogger(ClientConnectionManager.class);
/**
* Returns the authenticated device associated with the given local address, if any. An authenticated device is
* available if and only if the given local address maps to an active local connection and that connection is
* authenticated (i.e. not anonymous).
*
* @param localAddress the local address for which to find an authenticated device
*
* @return the authenticated device associated with the given local address, if any
*/
public Optional<AuthenticatedDevice> getAuthenticatedDevice(final LocalAddress localAddress) {
return getAuthenticatedDevice(remoteChannelsByLocalAddress.get(localAddress));
}
private Optional<AuthenticatedDevice> getAuthenticatedDevice(@Nullable final Channel remoteChannel) {
return Optional.ofNullable(remoteChannel)
.map(channel -> channel.attr(AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).get());
}
/**
* Returns the parsed acceptable languages associated with the given local address, if any. Acceptable languages may
* be unavailable if the local connection associated with the given local address has already closed, if the client
* did not provide a list of acceptable languages, or the list provided by the client could not be parsed.
*
* @param localAddress the local address for which to find acceptable languages
*
* @return the acceptable languages associated with the given local address, if any
*/
public Optional<List<Locale.LanguageRange>> getAcceptableLanguages(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
}
/**
* Returns the remote address associated with the given local address, if any. A remote address may be unavailable if
* the local connection associated with the given local address has already closed.
*
* @param localAddress the local address for which to find a remote address
*
* @return the remote address associated with the given local address, if any
*/
public Optional<InetAddress> getRemoteAddress(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
/**
* Returns the parsed user agent provided by the client the opened the connection associated with the given local
* address. This method may return an empty value if no active local connection is associated with the given local
* address or if the client's user-agent string was not recognized.
*
* @param localAddress the local address for which to find a User-Agent string
*
* @return the user agent associated with the given local address
*/
public Optional<UserAgent> getUserAgent(final LocalAddress localAddress) {
return Optional.ofNullable(remoteChannelsByLocalAddress.get(localAddress))
.map(remoteChannel -> remoteChannel.attr(PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
/**
* Closes any client connections to this host associated with the given authenticated device.
*
* @param authenticatedDevice the authenticated device for which to close connections
*/
public void closeConnection(final AuthenticatedDevice authenticatedDevice) {
// Channels will actually get removed from the list/map by their closeFuture listeners
remoteChannelsByAuthenticatedDevice.getOrDefault(authenticatedDevice, Collections.emptyList()).forEach(channel ->
channel.writeAndFlush(new CloseWebSocketFrame(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED
.toWebSocketCloseStatus("Reauthentication required")))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE));
}
@VisibleForTesting
@Nullable List<Channel> getRemoteChannelsByAuthenticatedDevice(final AuthenticatedDevice authenticatedDevice) {
return remoteChannelsByAuthenticatedDevice.get(authenticatedDevice);
}
@VisibleForTesting
Channel getRemoteChannelByLocalAddress(final LocalAddress localAddress) {
return remoteChannelsByLocalAddress.get(localAddress);
}
/**
* Handles successful completion of a WebSocket handshake and associates attributes and headers from the handshake
* request with the channel via which the handshake took place.
*
* @param channel the channel that completed a WebSocket handshake
* @param preferredRemoteAddress the preferred remote address (potentially from a request header) for the handshake
* @param userAgentHeader the value of the User-Agent header provided in the handshake request; may be {@code null}
* @param acceptLanguageHeader the value of the Accept-Language header provided in the handshake request; may be
* {@code null}
*/
static void handleWebSocketHandshakeComplete(final Channel channel,
final InetAddress preferredRemoteAddress,
@Nullable final String userAgentHeader,
@Nullable final String acceptLanguageHeader) {
channel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(preferredRemoteAddress);
if (StringUtils.isNotBlank(userAgentHeader)) {
channel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).set(userAgentHeader);
try {
channel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY)
.set(UserAgentUtil.parseUserAgentString(userAgentHeader));
} catch (final UnrecognizedUserAgentException ignored) {
}
}
if (StringUtils.isNotBlank(acceptLanguageHeader)) {
try {
channel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(Locale.LanguageRange.parse(acceptLanguageHeader));
} catch (final IllegalArgumentException e) {
log.debug("Invalid Accept-Language header from User-Agent {}: {}", userAgentHeader, acceptLanguageHeader, e);
}
}
}
/**
* Handles successful establishment of a Noise-over-WebSocket connection from a remote client to a local gRPC server.
*
* @param localChannel the newly-opened local channel between the Noise-over-WebSocket tunnel and the local gRPC
* server
* @param remoteChannel the channel from the remote client to the Noise-over-WebSocket tunnel
* @param maybeAuthenticatedDevice the authenticated device (if any) associated with the new connection
*/
void handleConnectionEstablished(final LocalChannel localChannel,
final Channel remoteChannel,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
maybeAuthenticatedDevice.ifPresent(authenticatedDevice ->
remoteChannel.attr(ClientConnectionManager.AUTHENTICATED_DEVICE_ATTRIBUTE_KEY).set(authenticatedDevice));
remoteChannelsByLocalAddress.put(localChannel.localAddress(), remoteChannel);
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
final List<Channel> channels = existingChannelList != null ? existingChannelList : new ArrayList<>();
channels.add(remoteChannel);
return channels;
}));
remoteChannel.closeFuture().addListener(closeFuture -> {
remoteChannelsByLocalAddress.remove(localChannel.localAddress());
getAuthenticatedDevice(remoteChannel).ifPresent(authenticatedDevice ->
remoteChannelsByAuthenticatedDevice.compute(authenticatedDevice, (ignored, existingChannelList) -> {
if (existingChannelList == null) {
return null;
}
existingChannelList.remove(remoteChannel);
return existingChannelList.isEmpty() ? null : existingChannelList;
}));
});
}
}

View File

@@ -22,15 +22,21 @@ import org.slf4j.LoggerFactory;
*/
class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
private final ClientConnectionManager clientConnectionManager;
private final LocalAddress authenticatedGrpcServerAddress;
private final LocalAddress anonymousGrpcServerAddress;
private final List<Object> pendingReads = new ArrayList<>();
private static final Logger log = LoggerFactory.getLogger(EstablishLocalGrpcConnectionHandler.class);
public EstablishLocalGrpcConnectionHandler(final LocalAddress authenticatedGrpcServerAddress,
public EstablishLocalGrpcConnectionHandler(final ClientConnectionManager clientConnectionManager,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress) {
this.clientConnectionManager = clientConnectionManager;
this.authenticatedGrpcServerAddress = authenticatedGrpcServerAddress;
this.anonymousGrpcServerAddress = anonymousGrpcServerAddress;
}
@@ -41,7 +47,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
}
@Override
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) throws Exception {
public void userEventTriggered(final ChannelHandlerContext remoteChannelContext, final Object event) {
if (event instanceof NoiseHandshakeCompleteEvent noiseHandshakeCompleteEvent) {
// We assume that we'll only get a completed handshake event if the handshake met all authentication requirements
// for the requested service. If the handshake doesn't have an authenticated device, we assume we're trying to
@@ -53,7 +59,6 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
new Bootstrap()
.remoteAddress(grpcServerAddress)
// TODO Set local address
.channel(LocalChannel.class)
.group(remoteChannelContext.channel().eventLoop())
.handler(new ChannelInitializer<LocalChannel>() {
@@ -63,15 +68,19 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
}
})
.connect()
.addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
.addListener((ChannelFutureListener) localChannelFuture -> {
if (localChannelFuture.isSuccess()) {
clientConnectionManager.handleConnectionEstablished((LocalChannel) localChannelFuture.channel(),
remoteChannelContext.channel(),
noiseHandshakeCompleteEvent.authenticatedDevice());
// Close the local connection if the remote channel closes and vice versa
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> future.channel().close());
future.channel().closeFuture().addListener(closeFuture ->
remoteChannelContext.channel().closeFuture().addListener(closeFuture -> localChannelFuture.channel().close());
localChannelFuture.channel().closeFuture().addListener(closeFuture ->
remoteChannelContext.write(new CloseWebSocketFrame(WebSocketCloseStatus.SERVICE_RESTART)));
remoteChannelContext.pipeline()
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(future.channel()));
.addAfter(remoteChannelContext.name(), null, new ProxyHandler(localChannelFuture.channel()));
// Flush any buffered reads we accumulated while waiting to open the connection
pendingReads.forEach(remoteChannelContext::fireChannelRead);
@@ -79,7 +88,7 @@ class EstablishLocalGrpcConnectionHandler extends ChannelInboundHandlerAdapter {
remoteChannelContext.pipeline().remove(EstablishLocalGrpcConnectionHandler.this);
} else {
log.warn("Failed to establish local connection to gRPC server", future.cause());
log.warn("Failed to establish local connection to gRPC server", localChannelFuture.cause());
remoteChannelContext.close();
}
});

View File

@@ -0,0 +1,134 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.InetAddresses;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketCloseStatus;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Optional;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A WebSocket handshake handler waits for a WebSocket handshake to complete, then replaces itself with the appropriate
* Noise handshake handler for the requested path.
*/
class WebsocketHandshakeCompleteHandler extends ChannelInboundHandlerAdapter {
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
private final byte[] publicKeySignature;
private final byte[] recognizedProxySecret;
private static final Logger log = LoggerFactory.getLogger(WebsocketHandshakeCompleteHandler.class);
@VisibleForTesting
static final String RECOGNIZED_PROXY_SECRET_HEADER = "X-Signal-Recognized-Proxy";
@VisibleForTesting
static final String FORWARDED_FOR_HEADER = "X-Forwarded-For";
WebsocketHandshakeCompleteHandler(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature,
final String recognizedProxySecret) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
this.publicKeySignature = publicKeySignature;
// The recognized proxy secret is an arbitrary string, and not an encoded byte sequence (i.e. a base64- or hex-
// encoded value). We convert it into a byte array here for easier constant-time comparisons via
// MessageDigest.equals() later.
this.recognizedProxySecret = recognizedProxySecret.getBytes(StandardCharsets.UTF_8);
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final InetAddress preferredRemoteAddress;
{
final Optional<InetAddress> maybePreferredRemoteAddress =
getPreferredRemoteAddress(context, handshakeCompleteEvent);
if (maybePreferredRemoteAddress.isEmpty()) {
context.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR,
"Could not determine remote address"))
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
return;
}
preferredRemoteAddress = maybePreferredRemoteAddress.get();
}
ClientConnectionManager.handleWebSocketHandshakeComplete(context.channel(),
preferredRemoteAddress,
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.USER_AGENT),
handshakeCompleteEvent.requestHeaders().getAsString(HttpHeaderNames.ACCEPT_LANGUAGE));
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
default -> {
// The WebSocketOpeningHandshakeHandler should have caught all of these cases already; we'll consider it an
// internal error if something slipped through.
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
}
};
context.pipeline().replace(WebsocketHandshakeCompleteHandler.this, null, noiseHandshakeHandler);
}
context.fireUserEventTriggered(event);
}
private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerContext context,
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final byte[] recognizedProxySecretFromHeader =
handshakeCompleteEvent.requestHeaders().get(RECOGNIZED_PROXY_SECRET_HEADER, "")
.getBytes(StandardCharsets.UTF_8);
final boolean trustForwardedFor = MessageDigest.isEqual(recognizedProxySecret, recognizedProxySecretFromHeader);
if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) {
final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER);
return RemoteAddressFilter.getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
try {
return InetAddresses.forString(mostRecentProxy);
} catch (final IllegalArgumentException e) {
log.warn("Failed to parse forwarded-for address: {}", forwardedFor, e);
return null;
}
});
} else {
// Either we don't trust the forwarded-for header or it's not present
if (context.channel().remoteAddress() instanceof InetSocketAddress inetSocketAddress) {
return Optional.of(inetSocketAddress.getAddress());
} else {
log.warn("Channel's remote address was not an InetSocketAddress");
return Optional.empty();
}
}
}
}

View File

@@ -1,52 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
/**
* A WebSocket handshake listener waits for a WebSocket handshake to complete, then replaces itself with the appropriate
* Noise handshake handler for the requested path.
*/
class WebsocketHandshakeCompleteListener extends ChannelInboundHandlerAdapter {
private final ClientPublicKeysManager clientPublicKeysManager;
private final ECKeyPair ecKeyPair;
private final byte[] publicKeySignature;
WebsocketHandshakeCompleteListener(final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature) {
this.clientPublicKeysManager = clientPublicKeysManager;
this.ecKeyPair = ecKeyPair;
this.publicKeySignature = publicKeySignature;
}
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
if (event instanceof WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent) {
final ChannelHandler noiseHandshakeHandler = switch (handshakeCompleteEvent.requestUri()) {
case WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH ->
new NoiseXXHandshakeHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature);
case WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH ->
new NoiseNXHandshakeHandler(ecKeyPair, publicKeySignature);
default -> {
// The HttpHandler should have caught all of these cases already; we'll consider it an internal error if
// something slipped through.
throw new IllegalArgumentException("Unexpected URI: " + handshakeCompleteEvent.requestUri());
}
};
context.pipeline().replace(WebsocketHandshakeCompleteListener.this, null, noiseHandshakeHandler);
}
context.fireUserEventTriggered(event);
}
}

View File

@@ -48,11 +48,13 @@ public class WebsocketNoiseTunnelServer implements Managed {
final PrivateKey tlsPrivateKey,
final NioEventLoopGroup eventLoopGroup,
final Executor delegatedTaskExecutor,
final ClientConnectionManager clientConnectionManager,
final ClientPublicKeysManager clientPublicKeysManager,
final ECKeyPair ecKeyPair,
final byte[] publicKeySignature,
final LocalAddress authenticatedGrpcServerAddress,
final LocalAddress anonymousGrpcServerAddress) throws SSLException {
final LocalAddress anonymousGrpcServerAddress,
final String recognizedProxySecret) throws SSLException {
final SslProvider sslProvider;
@@ -88,10 +90,10 @@ public class WebsocketNoiseTunnelServer implements Managed {
.addLast(new RejectUnsupportedMessagesHandler())
// The WebSocket handshake complete listener will replace itself with an appropriate Noise handshake handler once
// a WebSocket handshake has been completed
.addLast(new WebsocketHandshakeCompleteListener(clientPublicKeysManager, ecKeyPair, publicKeySignature))
.addLast(new WebsocketHandshakeCompleteHandler(clientPublicKeysManager, ecKeyPair, publicKeySignature, recognizedProxySecret))
// This handler will open a local connection to the appropriate gRPC server and install a ProxyHandler
// once the Noise handshake has completed
.addLast(new EstablishLocalGrpcConnectionHandler(authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
.addLast(new EstablishLocalGrpcConnectionHandler(clientConnectionManager, authenticatedGrpcServerAddress, anonymousGrpcServerAddress))
.addLast(new ErrorHandler());
}
});

View File

@@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.util.ua;
import com.google.common.annotations.VisibleForTesting;
import com.vdurmont.semver4j.Semver;
import io.grpc.Context;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
public class UserAgentUtil {
public static final Context.Key<UserAgent> USER_AGENT_CONTEXT_KEY = Context.key("x-signal-user-agent");
private static final Pattern STANDARD_UA_PATTERN = Pattern.compile("^Signal-(Android|Desktop|iOS)/([^ ]+)( (.+))?$", Pattern.CASE_INSENSITIVE);
public static UserAgent parseUserAgentString(final String userAgentString) throws UnrecognizedUserAgentException {
@@ -36,10 +33,6 @@ public class UserAgentUtil {
throw new UnrecognizedUserAgentException();
}
public static UserAgent userAgentFromGrpcContext() {
return USER_AGENT_CONTEXT_KEY.get();
}
@VisibleForTesting
static UserAgent parseStandardUserAgentString(final String userAgentString) {
final Matcher matcher = STANDARD_UA_PATTERN.matcher(userAgentString);