Introduce a Noise-over-WebSocket client connection manager

This commit is contained in:
Jon Chambers
2024-03-22 15:20:55 -04:00
committed by GitHub
parent 075a08884b
commit aec6ac019f
53 changed files with 1818 additions and 933 deletions

View File

@@ -0,0 +1,77 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.mockito.Mockito.mock;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
abstract class AbstractAuthenticationInterceptorTest {
private static DefaultEventLoopGroup eventLoopGroup;
private ClientConnectionManager clientConnectionManager;
private Server server;
private ManagedChannel managedChannel;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-authentication-interceptor-server");
clientConnectionManager = mock(ClientConnectionManager.class);
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(getInterceptor())
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
protected abstract AbstractAuthenticationInterceptor getInterceptor();
protected ClientConnectionManager getClientConnectionManager() {
return clientConnectionManager;
}
protected GetAuthenticatedDeviceResponse getAuthenticatedDevice() {
return RequestAttributesGrpc.newBlockingStub(managedChannel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
}
}

View File

@@ -1,135 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import java.io.IOException;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.Executor;
import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
class BasicCredentialAuthenticationInterceptorTest {
private Server server;
private ManagedChannel managedChannel;
private AccountAuthenticator accountAuthenticator;
@BeforeEach
void setUp() throws IOException {
accountAuthenticator = mock(AccountAuthenticator.class);
final BasicCredentialAuthenticationInterceptor authenticationInterceptor =
new BasicCredentialAuthenticationInterceptor(accountAuthenticator);
final String serverName = InProcessServerBuilder.generateName();
server = InProcessServerBuilder.forName(serverName)
.directExecutor()
.intercept(authenticationInterceptor)
.addService(new EchoServiceImpl())
.build()
.start();
managedChannel = InProcessChannelBuilder.forName(serverName)
.directExecutor()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@ParameterizedTest
@MethodSource
void interceptCall(final Metadata headers, final boolean acceptCredentials, final boolean expectAuthentication) {
if (acceptCredentials) {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(UUID.randomUUID());
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
} else {
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.empty());
}
final EchoServiceGrpc.EchoServiceBlockingStub stub = EchoServiceGrpc.newBlockingStub(managedChannel)
.withCallCredentials(new CallCredentials() {
@Override
public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor, final MetadataApplier applier) {
applier.apply(headers);
}
@Override
public void thisUsesUnstableApi() {
}
});
if (expectAuthentication) {
assertDoesNotThrow(() -> stub.echo(EchoRequest.newBuilder().build()));
} else {
final StatusRuntimeException exception =
assertThrows(StatusRuntimeException.class, () -> stub.echo(EchoRequest.newBuilder().build()));
assertEquals(Status.UNAUTHENTICATED.getCode(), exception.getStatus().getCode());
}
}
private static Stream<Arguments> interceptCall() {
final Metadata malformedCredentialHeaders = new Metadata();
malformedCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, "Incorrect");
final Metadata structurallyValidCredentialHeaders = new Metadata();
structurallyValidCredentialHeaders.put(
BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS,
HeaderUtils.basicAuthHeader(UUID.randomUUID().toString(), RandomStringUtils.randomAlphanumeric(16))
);
return Stream.of(
Arguments.of(new Metadata(), true, false),
Arguments.of(malformedCredentialHeaders, true, false),
Arguments.of(structurallyValidCredentialHeaders, false, false),
Arguments.of(structurallyValidCredentialHeaders, true, true)
);
}
}

View File

@@ -13,15 +13,14 @@ import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import java.util.UUID;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.Pair;
public class MockAuthenticationInterceptor implements ServerInterceptor {
@Nullable
private Pair<UUID, Byte> authenticatedDevice;
private AuthenticatedDevice authenticatedDevice;
public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) {
authenticatedDevice = new Pair<>(accountIdentifier, deviceId);
authenticatedDevice = new AuthenticatedDevice(accountIdentifier, deviceId);
}
public void clearAuthenticatedDevice() {
@@ -33,14 +32,10 @@ public class MockAuthenticationInterceptor implements ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
if (authenticatedDevice != null) {
final Context context = Context.current()
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedDevice.first())
.withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedDevice.second());
return Contexts.interceptCall(context, call, headers, next);
}
return next.startCall(call, headers);
return authenticatedDevice != null
? Contexts.interceptCall(
Context.current().withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE, authenticatedDevice),
call, headers, next)
: next.startCall(call, headers);
}
}

View File

@@ -0,0 +1,40 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
@Override
protected AbstractAuthenticationInterceptor getInterceptor() {
return new ProhibitAuthenticationInterceptor(getClientConnectionManager());
}
@Test
void interceptCall() {
final ClientConnectionManager clientConnectionManager = getClientConnectionManager();
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
}
}

View File

@@ -0,0 +1,39 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import io.grpc.Status;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterceptorTest {
@Override
protected AbstractAuthenticationInterceptor getInterceptor() {
return new RequireAuthenticationInterceptor(getClientConnectionManager());
}
@Test
void interceptCall() {
final ClientConnectionManager clientConnectionManager = getClientConnectionManager();
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(clientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
}
}

View File

@@ -39,10 +39,12 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest {
@@ -126,17 +128,25 @@ class RemoteDeprecationFilterTest {
@ParameterizedTest
@MethodSource(value="testFilter")
void testGrpcFilter(final String userAgent, final boolean expectDeprecation) throws Exception {
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()
.addService(new EchoServiceImpl())
.intercept(filterConfiguredForTest())
.intercept(new UserAgentInterceptor())
.intercept(mockRequestAttributesInterceptor)
.build()
.start();
final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()
.userAgent(userAgent)
.userAgent(userAgentString)
.build();
try {

View File

@@ -1,79 +0,0 @@
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class AcceptLanguageInterceptorTest {
@ParameterizedTest
@MethodSource
void parseLocale(final String header, final List<Locale> expectedLocales) throws IOException, InterruptedException {
final AtomicReference<List<Locale>> observedLocales = new AtomicReference<>(null);
final EchoServiceImpl serviceImpl = new EchoServiceImpl() {
@Override
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
observedLocales.set(AcceptLanguageUtil.localeFromGrpcContext());
super.echo(req, responseObserver);
}
};
final Server testServer = InProcessServerBuilder.forName("AcceptLanguageTest")
.directExecutor()
.addService(serviceImpl)
.intercept(new AcceptLanguageInterceptor())
.intercept(new UserAgentInterceptor())
.build()
.start();
try {
final ManagedChannel channel = InProcessChannelBuilder.forName("AcceptLanguageTest")
.directExecutor()
.userAgent("Signal-Android/1.2.3")
.build();
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, header);
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel)
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
final EchoRequest request = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("test request")).build();
client.echo(request);
assertEquals(expectedLocales, observedLocales.get());
} finally {
testServer.shutdownNow();
testServer.awaitTermination();
}
}
private static Stream<Arguments> parseLocale() {
return Stream.of(
// en-US-POSIX is a special locale that exists alongside en-US. It matches because of the definition of
// basic filtering in RFC 4647 (https://datatracker.ietf.org/doc/html/rfc4647#section-3.3.1)
Arguments.of("en-US,fr-CA", List.of(Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"), Locale.forLanguageTag("fr-CA"))),
Arguments.of("en-US; q=0.9, fr-CA", List.of(Locale.forLanguageTag("fr-CA"), Locale.forLanguageTag("en-US-POSIX"), Locale.forLanguageTag("en-US"))),
Arguments.of("invalid-locale,fr-CA", List.of(Locale.forLanguageTag("fr-CA"))),
Arguments.of("", Collections.emptyList()),
Arguments.of("acompletely,unexpectedfor , mat", Collections.emptyList())
);
}
}

View File

@@ -13,9 +13,9 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.net.InetSocketAddress;
import java.time.Duration;
import java.util.Optional;
import java.util.UUID;
@@ -72,7 +72,7 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRemoteAddressInterceptor().setRemoteAddress(new InetSocketAddress("127.0.0.1", 12345));
getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1"));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
}

View File

@@ -29,21 +29,21 @@ public final class GrpcTestUtils {
public static void setupAuthenticatedExtension(
final GrpcServerExtension extension,
final MockAuthenticationInterceptor mockAuthenticationInterceptor,
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor,
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor,
final UUID authenticatedAci,
final byte authenticatedDeviceId,
final BindableService service) {
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
.addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, mockAuthenticationInterceptor, new ErrorMappingInterceptor()));
}
public static void setupUnauthenticatedExtension(
final GrpcServerExtension extension,
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor,
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor,
final BindableService service) {
extension.getServiceRegistry()
.addService(ServerInterceptors.intercept(service, mockRemoteAddressInterceptor, new ErrorMappingInterceptor()));
.addService(ServerInterceptors.intercept(service, mockRequestAttributesInterceptor, new ErrorMappingInterceptor()));
}
public static void assertStatusException(final Status expected, final Executable serviceCall) {

View File

@@ -1,37 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import java.net.SocketAddress;
import javax.annotation.Nullable;
public class MockRemoteAddressInterceptor implements ServerInterceptor {
@Nullable
private SocketAddress remoteAddress;
public void setRemoteAddress(@Nullable final SocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
return remoteAddress == null
? next.startCall(serverCall, headers)
: Contexts.interceptCall(
Context.current().withValue(RemoteAddressUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress),
serverCall, headers, next);
}
}

View File

@@ -0,0 +1,64 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor {
@Nullable
private InetAddress remoteAddress;
@Nullable
private UserAgent userAgent;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(final ServerCall<ReqT, RespT> serverCall,
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
Context context = Context.current();
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
}
}

View File

@@ -17,16 +17,13 @@ import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.stub.MetadataUtils;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
@@ -79,6 +76,8 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@@ -100,6 +99,14 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
return new ProfileAnonymousGrpcService(
accountsManager,
profilesManager,
@@ -108,14 +115,6 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
);
}
@Override
protected ProfileAnonymousGrpc.ProfileAnonymousBlockingStub createStub(final Channel channel) throws ClassNotFoundException, InvocationTargetException, NoSuchMethodException, IllegalAccessException {
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
return super.createStub(channel).withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
}
@Test
void getUnversionedProfile() {
final UUID targetUuid = UUID.randomUUID();

View File

@@ -25,7 +25,6 @@ import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusEx
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString;
import io.grpc.Metadata;
import io.grpc.Status;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
@@ -34,6 +33,7 @@ import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
@@ -103,6 +103,8 @@ import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@@ -167,9 +169,14 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
final String phoneNumber = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final Metadata metadata = new Metadata();
metadata.put(AcceptLanguageInterceptor.ACCEPTABLE_LANGUAGES_GRPC_HEADER, "en-us");
metadata.put(UserAgentInterceptor.USER_AGENT_GRPC_HEADER, "Signal-Android/1.2.3");
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());

View File

@@ -0,0 +1,57 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.stub.StreamObserver;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestAttributesImplBase {
@Override
public void getRequestAttributes(final GetRequestAttributesRequest request,
final StreamObserver<GetRequestAttributesResponse> responseObserver) {
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages ->
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder()
.setPlatform(userAgent.getPlatform().toString())
.setVersion(userAgent.getVersion().toString())
.setAdditionalSpecifiers(userAgent.getAdditionalSpecifiers().orElse(""))
.build()));
responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted();
}
@Override
public void getAuthenticatedDevice(final GetAuthenticatedDeviceRequest request,
final StreamObserver<GetAuthenticatedDeviceResponse> responseObserver) {
final GetAuthenticatedDeviceResponse.Builder responseBuilder = GetAuthenticatedDeviceResponse.newBuilder();
try {
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
responseBuilder.setAccountIdentifier(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()));
responseBuilder.setDeviceId(authenticatedDevice.deviceId());
} catch (final Exception ignored) {
}
responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted();
}
}

View File

@@ -0,0 +1,160 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.ClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest {
private static DefaultEventLoopGroup eventLoopGroup;
private ClientConnectionManager clientConnectionManager;
private Server server;
private ManagedChannel managedChannel;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
clientConnectionManager = mock(ClientConnectionManager.class);
when(clientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
}
@Test
void getAcceptableLanguages() {
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty());
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList());
}
@Test
void getAvailableAcceptedLocales() {
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(clientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
}
@Test
void getRemoteAddress() {
when(clientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes);
final String remoteAddressString = "6.7.8.9";
when(clientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
when(clientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.empty());
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(clientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
}
private GetRequestAttributesResponse getRequestAttributes() {
return RequestAttributesGrpc.newBlockingStub(managedChannel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
}
}

View File

@@ -60,7 +60,7 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
private AutoCloseable mocksCloseable;
private final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
private final MockRemoteAddressInterceptor mockRemoteAddressInterceptor = new MockRemoteAddressInterceptor();
private final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
private SERVICE service;
@@ -114,8 +114,8 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
mocksCloseable = MockitoAnnotations.openMocks(this);
service = requireNonNull(createServiceBeforeEachTest(), "created service must not be `null`");
GrpcTestUtils.setupAuthenticatedExtension(
GRPC_SERVER_EXTENSION_AUTHENTICATED, mockAuthenticationInterceptor, mockRemoteAddressInterceptor, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, mockRemoteAddressInterceptor, service);
GRPC_SERVER_EXTENSION_AUTHENTICATED, mockAuthenticationInterceptor, mockRequestAttributesInterceptor, AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, service);
GrpcTestUtils.setupUnauthenticatedExtension(GRPC_SERVER_EXTENSION_UNAUTHENTICATED, mockRequestAttributesInterceptor, service);
try {
authenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_AUTHENTICATED.getChannel());
unauthenticatedServiceStub = createStub(GRPC_SERVER_EXTENSION_UNAUTHENTICATED.getChannel());
@@ -145,8 +145,8 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
return unauthenticatedServiceStub;
}
protected MockRemoteAddressInterceptor getMockRemoteAddressInterceptor() {
return mockRemoteAddressInterceptor;
protected MockRequestAttributesInterceptor getMockRequestAttributesInterceptor() {
return mockRequestAttributesInterceptor;
}
protected MockAuthenticationInterceptor getMockAuthenticationInterceptor() {

View File

@@ -1,90 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
public class UserAgentInterceptorTest {
@ParameterizedTest
@MethodSource
void testInterceptor(final String header, final ClientPlatform platform, final String version) throws Exception {
final AtomicReference<UserAgent> observedUserAgent = new AtomicReference<>(null);
final EchoServiceImpl serviceImpl = new EchoServiceImpl() {
@Override
public void echo(EchoRequest req, StreamObserver<EchoResponse> responseObserver) {
observedUserAgent.set(UserAgentUtil.userAgentFromGrpcContext());
super.echo(req, responseObserver);
}
};
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()
.addService(serviceImpl)
.intercept(new UserAgentInterceptor())
.build()
.start();
try {
final ManagedChannel channel = InProcessChannelBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()
.userAgent(header)
.build();
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
final EchoRequest req = EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("cluck cluck, i'm a parrot")).build();
assertEquals("cluck cluck, i'm a parrot", client.echo(req).getPayload().toStringUtf8());
if (platform == null) {
assertNull(observedUserAgent.get());
} else {
assertEquals(platform, observedUserAgent.get().getPlatform());
assertEquals(new Semver(version), observedUserAgent.get().getVersion());
// can't assert on the additional specifiers because they include internal details of the grpc in-process channel itself
}
} finally {
testServer.shutdownNow();
testServer.awaitTermination();
}
}
private static Stream<Arguments> testInterceptor() {
return Stream.of(
Arguments.of(null, null, null),
Arguments.of("", null, null),
Arguments.of("Unrecognized UA", null, null),
Arguments.of("Signal-Android/4.68.3", ClientPlatform.ANDROID, "4.68.3"),
Arguments.of("Signal-iOS/3.9.0", ClientPlatform.IOS, "3.9.0"),
Arguments.of("Signal-Desktop/1.2.3", ClientPlatform.DESKTOP, "1.2.3"),
Arguments.of("Signal-Desktop/8.0.0-beta.2", ClientPlatform.DESKTOP, "8.0.0-beta.2"),
Arguments.of("Signal-iOS/8.0.0-beta.2", ClientPlatform.IOS, "8.0.0-beta.2"));
}
}

View File

@@ -1,21 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.grpc.stub.StreamObserver;
import org.signal.chat.rpc.AuthenticationTypeGrpc;
import org.signal.chat.rpc.GetAuthenticatedRequest;
import org.signal.chat.rpc.GetAuthenticatedResponse;
public class AuthenticationTypeService extends AuthenticationTypeGrpc.AuthenticationTypeImplBase {
private final boolean authenticated;
public AuthenticationTypeService(final boolean authenticated) {
this.authenticated = authenticated;
}
@Override
public void getAuthenticated(final GetAuthenticatedRequest request, final StreamObserver<GetAuthenticatedResponse> responseObserver) {
responseObserver.onNext(GetAuthenticatedResponse.newBuilder().setAuthenticated(authenticated).build());
responseObserver.onCompleted();
}
}

View File

@@ -0,0 +1,283 @@
package org.whispersystems.textsecuregcm.grpc.net;
import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class ClientConnectionManagerTest {
private static EventLoopGroup eventLoopGroup;
private LocalChannel localChannel;
private LocalChannel remoteChannel;
private LocalServerChannel localServerChannel;
private ClientConnectionManager clientConnectionManager;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws InterruptedException {
eventLoopGroup = new DefaultEventLoopGroup();
clientConnectionManager = new ClientConnectionManager();
// We have to jump through some hoops to get "real" LocalChannel instances to test with, and so we run a trivial
// local server to which we can open trivial local connections
localServerChannel = (LocalServerChannel) new ServerBootstrap()
.group(eventLoopGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel channel) {
}
})
.bind(new LocalAddress("test-server"))
.await()
.channel();
final Bootstrap clientBootstrap = new Bootstrap()
.group(eventLoopGroup)
.channel(LocalChannel.class)
.handler(new ChannelInitializer<>() {
@Override
protected void initChannel(final Channel ch) {
}
});
localChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
remoteChannel = (LocalChannel) clientBootstrap.connect(localServerChannel.localAddress()).await().channel();
}
@AfterEach
void tearDown() throws InterruptedException {
localChannel.close().await();
remoteChannel.close().await();
localServerChannel.close().await();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
}
@ParameterizedTest
@MethodSource
void getAuthenticatedDevice(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<AuthenticatedDevice> maybeAuthenticatedDevice) {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice,
clientConnectionManager.getAuthenticatedDevice(localChannel.localAddress()));
}
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
return List.of(
Optional.of(new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID)),
Optional.empty()
);
}
@Test
void getAcceptableLanguages() {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
clientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja");
remoteChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges);
assertEquals(Optional.of(acceptLanguageRanges),
clientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
}
@Test
void getRemoteAddress() {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
clientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
clientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
clientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
clientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel),
clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleWebSocketHandshakeCompleteRemoteAddress() {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
@Nullable final UserAgent expectedParsedUserAgent) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
userAgentHeader,
null);
assertEquals(userAgentHeader,
embeddedChannel.attr(ClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(ClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
return List.of(
// Recognized user-agent
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
// Unrecognized user-agent
Arguments.of("Not a valid user-agent string", null),
// Missing user-agent
Arguments.of(null, null)
);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
ClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader);
assertEquals(expectedLanguageRanges,
embeddedChannel.attr(ClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() {
return List.of(
// Parseable list
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")),
// Unparsable list
Arguments.of("This is not a valid language preference list", null),
// Missing list
Arguments.of(null, null)
);
}
@Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertNull(clientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleConnectionEstablishedAnonymous() throws InterruptedException {
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
clientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
remoteChannel.close().await();
assertNull(clientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
}
}

View File

@@ -8,8 +8,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
@@ -35,6 +35,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
private final ECPublicKey rootPublicKey;
@Nullable private final UUID accountIdentifier;
private final byte deviceId;
private final HttpHeaders headers;
private final SocketAddress remoteServerAddress;
private final WebSocketCloseListener webSocketCloseListener;
@@ -50,6 +51,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
final ECPublicKey rootPublicKey,
@Nullable final UUID accountIdentifier,
final byte deviceId,
final HttpHeaders headers,
final SocketAddress remoteServerAddress,
final WebSocketCloseListener webSocketCloseListener) {
@@ -60,6 +62,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
this.rootPublicKey = rootPublicKey;
this.accountIdentifier = accountIdentifier;
this.deviceId = deviceId;
this.headers = headers;
this.remoteServerAddress = remoteServerAddress;
this.webSocketCloseListener = webSocketCloseListener;
}
@@ -87,7 +90,7 @@ class EstablishRemoteConnectionHandler extends ChannelInboundHandlerAdapter {
WebSocketVersion.V13,
null,
false,
new DefaultHttpHeaders(),
headers,
Noise.MAX_PACKET_LEN,
10_000))
.addLast(new OutboundCloseWebSocketFrameHandler(webSocketCloseListener))

View File

@@ -11,6 +11,7 @@ import java.net.SocketAddress;
import java.net.URI;
import java.security.cert.X509Certificate;
import java.util.UUID;
import io.netty.handler.codec.http.HttpHeaders;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import javax.annotation.Nullable;
@@ -30,6 +31,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
final ECPublicKey rootPublicKey,
@Nullable final UUID accountIdentifier,
final byte deviceId,
final HttpHeaders headers,
final X509Certificate trustedServerCertificate,
final NioEventLoopGroup eventLoopGroup,
final WebSocketCloseListener webSocketCloseListener) {
@@ -48,6 +50,7 @@ class WebSocketNoiseTunnelClient implements AutoCloseable {
rootPublicKey,
accountIdentifier,
deviceId,
headers,
remoteServerAddress,
webSocketCloseListener));
}

View File

@@ -1,7 +1,6 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
@@ -17,6 +16,8 @@ import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
@@ -35,28 +36,41 @@ import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.AuthenticationTypeGrpc;
import org.signal.chat.rpc.GetAuthenticatedRequest;
import org.signal.chat.rpc.GetAuthenticatedResponse;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributesServiceImpl;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTest {
@@ -66,6 +80,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static X509Certificate serverTlsCertificate;
private ClientConnectionManager clientConnectionManager;
private ClientPublicKeysManager clientPublicKeysManager;
private ECKeyPair rootKeyPair;
@@ -79,6 +94,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private static final UUID ACCOUNT_IDENTIFIER = UUID.randomUUID();
private static final byte DEVICE_ID = Device.PRIMARY_ID;
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16);
// Please note that this certificate/key are used only for testing and are not used anywhere outside of this test.
// They were generated with:
//
@@ -133,6 +150,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
clientKeyPair = Curve.generateKeyPair();
final ECKeyPair serverKeyPair = Curve.generateKeyPair();
clientConnectionManager = new ClientConnectionManager();
clientPublicKeysManager = mock(ClientPublicKeysManager.class);
when(clientPublicKeysManager.findPublicKey(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@@ -146,7 +165,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
authenticatedGrpcServer = new ManagedLocalGrpcServer(authenticatedGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new AuthenticationTypeService(true));
serverBuilder.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
.intercept(new RequireAuthenticationInterceptor(clientConnectionManager));
}
};
@@ -155,7 +176,9 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
anonymousGrpcServer = new ManagedLocalGrpcServer(anonymousGrpcServerAddress, defaultEventLoopGroup) {
@Override
protected void configureServer(final ServerBuilder<?> serverBuilder) {
serverBuilder.addService(new AuthenticationTypeService(false));
serverBuilder.addService(new RequestAttributesServiceImpl())
.intercept(new RequestAttributesInterceptor(clientConnectionManager))
.intercept(new ProhibitAuthenticationInterceptor(clientConnectionManager));
}
};
@@ -166,11 +189,13 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
serverTlsPrivateKey,
nioEventLoopGroup,
delegatedTaskExecutor,
clientConnectionManager,
clientPublicKeysManager,
serverKeyPair,
rootKeyPair.getPrivateKey().calculateSignature(serverKeyPair.getPublicKey().getPublicKeyBytes()),
authenticatedGrpcServerAddress,
anonymousGrpcServerAddress);
anonymousGrpcServerAddress,
RECOGNIZED_PROXY_SECRET);
websocketNoiseTunnelServer.start();
}
@@ -198,10 +223,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertTrue(response.getAuthenticated());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
} finally {
channel.shutdown();
}
@@ -215,15 +241,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
// Try to verify the server's public key with something other than the key with which it was signed
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
buildAndStartAuthenticatedClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
@@ -247,8 +273,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
@@ -272,8 +298,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
@@ -294,6 +320,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootKeyPair.getPublicKey(),
ACCOUNT_IDENTIFIER,
DEVICE_ID,
new DefaultHttpHeaders(),
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
@@ -304,8 +331,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
@@ -320,10 +347,11 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedResponse response = AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build());
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertFalse(response.getAuthenticated());
assertTrue(response.getAccountIdentifier().isEmpty());
assertEquals(0, response.getDeviceId());
} finally {
channel.shutdown();
}
@@ -336,15 +364,15 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
// Try to verify the server's public key with something other than the key with which it was signed
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey())) {
buildAndStartAnonymousClient(webSocketCloseListener, Curve.generateKeyPair().getPublicKey(), new DefaultHttpHeaders())) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
@@ -365,6 +393,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootKeyPair.getPublicKey(),
null,
(byte) 0,
new DefaultHttpHeaders(),
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
@@ -375,8 +404,8 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
try {
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE,
() -> AuthenticationTypeGrpc.newBlockingStub(channel)
.getAuthenticated(GetAuthenticatedRequest.newBuilder().build()));
() -> RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build()));
} finally {
channel.shutdown();
}
@@ -438,6 +467,86 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
}
@Test
void getRequestAttributes() throws InterruptedException {
final String remoteAddress = "4.5.6.7";
final String acceptLanguage = "en";
final String userAgent = "Signal-Desktop/1.2.3 Linux";
final HttpHeaders headers = new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add("X-Forwarded-For", remoteAddress)
.add("Accept-Language", acceptLanguage)
.add("User-Agent", userAgent);
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient =
buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), headers)) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetRequestAttributesResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
} finally {
channel.shutdown();
}
}
}
@Test
void closeForReauthentication() throws InterruptedException {
final CountDownLatch connectionCloseLatch = new CountDownLatch(1);
final AtomicInteger serverCloseStatusCode = new AtomicInteger(0);
final AtomicBoolean closedByServer = new AtomicBoolean(false);
final WebSocketCloseListener webSocketCloseListener = new WebSocketCloseListener() {
@Override
public void handleWebSocketClosedByClient(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(false);
connectionCloseLatch.countDown();
}
@Override
public void handleWebSocketClosedByServer(final int statusCode) {
serverCloseStatusCode.set(statusCode);
closedByServer.set(true);
connectionCloseLatch.countDown();
}
};
try (final WebSocketNoiseTunnelClient webSocketNoiseTunnelClient = buildAndStartAuthenticatedClient(webSocketCloseListener)) {
final ManagedChannel channel = buildManagedChannel(webSocketNoiseTunnelClient.getLocalAddress());
try {
final GetAuthenticatedDeviceResponse response = RequestAttributesGrpc.newBlockingStub(channel)
.getAuthenticatedDevice(GetAuthenticatedDeviceRequest.newBuilder().build());
assertEquals(UUIDUtil.toByteString(ACCOUNT_IDENTIFIER), response.getAccountIdentifier());
assertEquals(DEVICE_ID, response.getDeviceId());
clientConnectionManager.closeConnection(new AuthenticatedDevice(ACCOUNT_IDENTIFIER, DEVICE_ID));
assertTrue(connectionCloseLatch.await(2, TimeUnit.SECONDS));
assertEquals(ApplicationWebSocketCloseReason.REAUTHENTICATION_REQUIRED.getStatusCode(),
serverCloseStatusCode.get());
assertTrue(closedByServer.get());
} finally {
channel.shutdown();
}
}
}
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient() throws InterruptedException {
return buildAndStartAuthenticatedClient(WebSocketCloseListener.NOOP_LISTENER);
}
@@ -445,11 +554,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener)
throws InterruptedException {
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey());
return buildAndStartAuthenticatedClient(webSocketCloseListener, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
}
private WebSocketNoiseTunnelClient buildAndStartAuthenticatedClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey) throws InterruptedException {
final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.AUTHENTICATED_WEBSOCKET_URI,
@@ -458,6 +568,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootPublicKey,
ACCOUNT_IDENTIFIER,
DEVICE_ID,
headers,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)
@@ -465,11 +576,12 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
}
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient() throws InterruptedException {
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey());
return buildAndStartAnonymousClient(WebSocketCloseListener.NOOP_LISTENER, rootKeyPair.getPublicKey(), new DefaultHttpHeaders());
}
private WebSocketNoiseTunnelClient buildAndStartAnonymousClient(final WebSocketCloseListener webSocketCloseListener,
final ECPublicKey rootPublicKey) throws InterruptedException {
final ECPublicKey rootPublicKey,
final HttpHeaders headers) throws InterruptedException {
return new WebSocketNoiseTunnelClient(websocketNoiseTunnelServer.getLocalAddress(),
WebSocketNoiseTunnelClient.ANONYMOUS_WEBSOCKET_URI,
@@ -478,6 +590,7 @@ class WebSocketNoiseTunnelServerIntegrationTest extends AbstractLeakDetectionTes
rootPublicKey,
null,
(byte) 0,
headers,
serverTlsCertificate,
nioEventLoopGroup,
webSocketCloseListener)

View File

@@ -0,0 +1,202 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import javax.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
private UserEventRecordingHandler userEventRecordingHandler;
private MutableRemoteAddressEmbeddedChannel embeddedChannel;
private static final String RECOGNIZED_PROXY_SECRET = RandomStringUtils.randomAlphanumeric(16);
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
private final List<Object> receivedEvents = new ArrayList<>();
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
receivedEvents.add(event);
}
public List<Object> getReceivedEvents() {
return receivedEvents;
}
}
private static class MutableRemoteAddressEmbeddedChannel extends EmbeddedChannel {
private SocketAddress remoteAddress;
public MutableRemoteAddressEmbeddedChannel(final ChannelHandler... handlers) {
super(handlers);
}
@Override
protected SocketAddress remoteAddress0() {
return isActive() ? remoteAddress : null;
}
public void setRemoteAddress(final SocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
}
@BeforeEach
void setUp() {
userEventRecordingHandler = new UserEventRecordingHandler();
embeddedChannel = new MutableRemoteAddressEmbeddedChannel(
new WebsocketHandshakeCompleteHandler(mock(ClientPublicKeysManager.class),
Curve.generateKeyPair(),
new byte[64],
RECOGNIZED_PROXY_SECRET),
userEventRecordingHandler);
embeddedChannel.setRemoteAddress(new InetSocketAddress("127.0.0.1", 0));
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
}
private static List<Arguments> handleWebSocketHandshakeComplete() {
return List.of(
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
}
@Test
void handleWebSocketHandshakeCompleteUnexpectedPath() {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteHandler.class));
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
}
@Test
void handleUnrecognizedEvent() {
final Object unrecognizedEvent = new Object();
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
}
@ParameterizedTest
@MethodSource
void getRemoteAddress(final HttpHeaders headers, final SocketAddress remoteAddress, @Nullable InetAddress expectedRemoteAddress) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(
WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, headers, null);
embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress,
embeddedChannel.attr(ClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
private static List<Arguments> getRemoteAddress() {
final InetSocketAddress remoteAddress = new InetSocketAddress("5.6.7.8", 0);
final InetAddress clientAddress = InetAddresses.forString("1.2.3.4");
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of(
// Recognized proxy, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
clientAddress),
// Recognized proxy, multiple forwarded-for addresses
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress,
proxyAddress),
// No recognized proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
// No recognized proxy header, no forwarded-for address
Arguments.of(new DefaultHttpHeaders(),
remoteAddress,
remoteAddress.getAddress()),
// Incorrect proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
// Recognized proxy, no forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress,
remoteAddress.getAddress()),
// Recognized proxy, bogus forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress,
null),
// No forwarded-for address, non-InetSocketAddress remote address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"),
null)
);
}
}

View File

@@ -1,91 +0,0 @@
package org.whispersystems.textsecuregcm.grpc.net;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
class WebsocketHandshakeCompleteListenerTest extends AbstractLeakDetectionTest {
private UserEventRecordingHandler userEventRecordingHandler;
private EmbeddedChannel embeddedChannel;
private static class UserEventRecordingHandler extends ChannelInboundHandlerAdapter {
private final List<Object> receivedEvents = new ArrayList<>();
@Override
public void userEventTriggered(final ChannelHandlerContext context, final Object event) {
receivedEvents.add(event);
}
public List<Object> getReceivedEvents() {
return receivedEvents;
}
}
@BeforeEach
void setUp() {
userEventRecordingHandler = new UserEventRecordingHandler();
embeddedChannel = new EmbeddedChannel(
new WebsocketHandshakeCompleteListener(mock(ClientPublicKeysManager.class), Curve.generateKeyPair(), new byte[64]),
userEventRecordingHandler);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeComplete(final String uri, final Class<? extends AbstractNoiseHandshakeHandler> expectedHandlerClass) {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete(uri, new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
assertNotNull(embeddedChannel.pipeline().get(expectedHandlerClass));
assertEquals(List.of(handshakeCompleteEvent), userEventRecordingHandler.getReceivedEvents());
}
private static List<Arguments> handleWebSocketHandshakeComplete() {
return List.of(
Arguments.of(WebsocketNoiseTunnelServer.AUTHENTICATED_SERVICE_PATH, NoiseXXHandshakeHandler.class),
Arguments.of(WebsocketNoiseTunnelServer.ANONYMOUS_SERVICE_PATH, NoiseNXHandshakeHandler.class));
}
@Test
void handleWebSocketHandshakeCompleteUnexpectedPath() {
final WebSocketServerProtocolHandler.HandshakeComplete handshakeCompleteEvent =
new WebSocketServerProtocolHandler.HandshakeComplete("/incorrect", new DefaultHttpHeaders(), null);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertNotNull(embeddedChannel.pipeline().get(WebsocketHandshakeCompleteListener.class));
assertThrows(IllegalArgumentException.class, () -> embeddedChannel.checkException());
}
@Test
void handleUnrecognizedEvent() {
final Object unrecognizedEvent = new Object();
embeddedChannel.pipeline().fireUserEventTriggered(unrecognizedEvent);
assertEquals(List.of(unrecognizedEvent), userEventRecordingHandler.getReceivedEvents());
}
}