Clarify guarantees around remote channnel/request attribute presence

This commit is contained in:
Jon Chambers
2025-04-18 15:44:21 -04:00
committed by GitHub
parent 28a0b9e84e
commit 8099d6465c
28 changed files with 405 additions and 570 deletions

View File

@@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Status;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -22,7 +23,7 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
}
@Test
void interceptCall() {
void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
@@ -34,6 +35,10 @@ class ProhibitAuthenticationInterceptorTest extends AbstractAuthenticationInterc
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
}
}

View File

@@ -9,6 +9,7 @@ import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -22,12 +23,12 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
}
@Test
void interceptCall() {
void interceptCall() throws ChannelNotFoundException {
final GrpcClientConnectionManager grpcClientConnectionManager = getClientConnectionManager();
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.empty());
GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, this::getAuthenticatedDevice);
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getAuthenticatedDevice);
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenReturn(Optional.of(authenticatedDevice));
@@ -35,5 +36,9 @@ class RequireAuthenticationInterceptorTest extends AbstractAuthenticationInterce
final GetAuthenticatedDeviceResponse response = getAuthenticatedDevice();
assertEquals(UUIDUtil.toByteString(authenticatedDevice.accountIdentifier()), response.getAccountIdentifier());
assertEquals(authenticatedDevice.deviceId(), response.getDeviceId());
when(grpcClientConnectionManager.getAuthenticatedDevice(any())).thenThrow(ChannelNotFoundException.class);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, this::getAuthenticatedDevice);
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertEquals;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
@@ -24,7 +25,6 @@ import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.core.Response;
import java.net.InetAddress;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Set;
@@ -39,6 +39,7 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcTestUtils;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.util.InetAddressRange;
@ExtendWith(DropwizardExtensionsSupport.class)
@@ -157,7 +158,7 @@ class ExternalRequestFilterTest {
@BeforeEach
void setUp() throws Exception {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRemoteAddress(InetAddress.getByName("127.0.0.1"));
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
testServer = InProcessServerBuilder.forName("ExternalRequestFilterTest")
.directExecutor()

View File

@@ -15,6 +15,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel;
@@ -40,11 +41,10 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RemoteDeprecationFilterTest {
@@ -130,11 +130,7 @@ class RemoteDeprecationFilterTest {
@MethodSource(value="testFilter")
void testGrpcFilter(final String userAgentString, final boolean expectDeprecation) throws IOException, InterruptedException {
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
try {
mockRequestAttributesInterceptor.setUserAgent(UserAgentUtil.parseUserAgentString(userAgentString));
} catch (UnrecognizedUserAgentException ignored) {
}
mockRequestAttributesInterceptor.setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"), userAgentString, null));
final Server testServer = InProcessServerBuilder.forName("RemoteDeprecationFilterTest")
.directExecutor()

View File

@@ -72,7 +72,8 @@ class AccountsAnonymousGrpcServiceTest extends
when(rateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
getMockRequestAttributesInterceptor().setRemoteAddress(InetAddresses.forString("127.0.0.1"));
getMockRequestAttributesInterceptor().setRequestAttributes(
new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null));
return new AccountsAnonymousGrpcService(accountsManager, rateLimiters);
}

View File

@@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.net.InetAddresses;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.Metadata;
@@ -19,25 +20,10 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgent;
public class MockRequestAttributesInterceptor implements ServerInterceptor {
@Nullable
private InetAddress remoteAddress;
private RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("127.0.0.1"), null, null);
@Nullable
private UserAgent userAgent;
@Nullable
private List<Locale.LanguageRange> acceptLanguage;
public void setRemoteAddress(@Nullable final InetAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public void setUserAgent(@Nullable final UserAgent userAgent) {
this.userAgent = userAgent;
}
public void setAcceptLanguage(@Nullable final List<Locale.LanguageRange> acceptLanguage) {
this.acceptLanguage = acceptLanguage;
public void setRequestAttributes(final RequestAttributes requestAttributes) {
this.requestAttributes = requestAttributes;
}
@Override
@@ -45,20 +31,7 @@ public class MockRequestAttributesInterceptor implements ServerInterceptor {
final Metadata headers,
final ServerCallHandler<ReqT, RespT> next) {
Context context = Context.current();
if (remoteAddress != null) {
context = context.withValue(RequestAttributesUtil.REMOTE_ADDRESS_CONTEXT_KEY, remoteAddress);
}
if (userAgent != null) {
context = context.withValue(RequestAttributesUtil.USER_AGENT_CONTEXT_KEY, userAgent);
}
if (acceptLanguage != null) {
context = context.withValue(RequestAttributesUtil.ACCEPT_LANGUAGE_CONTEXT_KEY, acceptLanguage);
}
return Contexts.interceptCall(context, serverCall, headers, next);
return Contexts.interceptCall(Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes), serverCall, headers, next);
}
}

View File

@@ -15,6 +15,7 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import java.nio.charset.StandardCharsets;
@@ -75,8 +76,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileAnonymousGrpcService, ProfileAnonymousGrpc.ProfileAnonymousBlockingStub> {
@@ -96,13 +95,9 @@ public class ProfileAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<ProfileA
@Override
protected ProfileAnonymousGrpcService createServiceBeforeEachTest() {
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
Locale.LanguageRange.parse("en-us")));
return new ProfileAnonymousGrpcService(
accountsManager,

View File

@@ -24,6 +24,7 @@ import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.common.net.InetAddresses;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.Phonenumber;
import com.google.protobuf.ByteString;
@@ -108,8 +109,6 @@ import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@@ -177,13 +176,9 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
getMockRequestAttributesInterceptor().setAcceptLanguage(Locale.LanguageRange.parse("en-us"));
try {
getMockRequestAttributesInterceptor().setUserAgent(UserAgentUtil.parseUserAgentString("Signal-Android/1.2.3"));
} catch (final UnrecognizedUserAgentException e) {
throw new IllegalArgumentException(e);
}
getMockRequestAttributesInterceptor().setRequestAttributes(new RequestAttributes(InetAddresses.forString("127.0.0.1"),
"Signal-Android/1.2.3",
Locale.LanguageRange.parse("en-us")));
when(rateLimiters.getProfileLimiter()).thenReturn(rateLimiter);
when(rateLimiter.validateReactive(any(UUID.class))).thenReturn(Mono.empty());

View File

@@ -1,13 +1,11 @@
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.stub.StreamObserver;
import org.apache.commons.lang3.StringUtils;
import org.signal.chat.rpc.GetAuthenticatedDeviceRequest;
import org.signal.chat.rpc.GetAuthenticatedDeviceResponse;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.signal.chat.rpc.UserAgent;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
@@ -20,21 +18,15 @@ public class RequestAttributesServiceImpl extends RequestAttributesGrpc.RequestA
final GetRequestAttributesResponse.Builder responseBuilder = GetRequestAttributesResponse.newBuilder();
RequestAttributesUtil.getAcceptableLanguages().ifPresent(acceptableLanguages ->
acceptableLanguages.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString())));
RequestAttributesUtil.getAcceptableLanguages()
.forEach(languageRange -> responseBuilder.addAcceptableLanguages(languageRange.toString()));
RequestAttributesUtil.getAvailableAcceptedLocales().forEach(locale ->
responseBuilder.addAvailableAcceptedLocales(locale.toLanguageTag()));
responseBuilder.setRemoteAddress(RequestAttributesUtil.getRemoteAddress().getHostAddress());
RequestAttributesUtil.getUserAgent().ifPresent(userAgent -> responseBuilder.setUserAgent(UserAgent.newBuilder()
.setPlatform(userAgent.platform().toString())
.setVersion(userAgent.version().toString())
.setAdditionalSpecifiers(StringUtils.stripToEmpty(userAgent.additionalSpecifiers()))
.build()));
RequestAttributesUtil.getRawUserAgent().ifPresent(responseBuilder::setRawUserAgent);
RequestAttributesUtil.getUserAgent().ifPresent(responseBuilder::setUserAgent);
responseObserver.onNext(responseBuilder.build());
responseObserver.onCompleted();

View File

@@ -3,172 +3,84 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.io.IOException;
import io.grpc.Context;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import java.util.concurrent.Callable;
import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.GetRequestAttributesRequest;
import org.signal.chat.rpc.GetRequestAttributesResponse;
import org.signal.chat.rpc.RequestAttributesGrpc;
import org.whispersystems.textsecuregcm.grpc.net.GrpcClientConnectionManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
class RequestAttributesUtilTest {
private static DefaultEventLoopGroup eventLoopGroup;
private static final InetAddress REMOTE_ADDRESS = InetAddresses.forString("127.0.0.1");
private GrpcClientConnectionManager grpcClientConnectionManager;
@Test
void getAcceptableLanguages() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAcceptableLanguages));
private Server server;
private ManagedChannel managedChannel;
@BeforeAll
static void setUpBeforeAll() {
eventLoopGroup = new DefaultEventLoopGroup();
}
@BeforeEach
void setUp() throws IOException {
final LocalAddress serverAddress = new LocalAddress("test-request-metadata-server");
grpcClientConnectionManager = mock(GrpcClientConnectionManager.class);
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString("127.0.0.1")));
// `RequestAttributesInterceptor` operates on `LocalAddresses`, so we need to do some slightly fancy plumbing to make
// sure that we're using local channels and addresses
server = NettyServerBuilder.forAddress(serverAddress)
.channelType(LocalServerChannel.class)
.bossEventLoopGroup(eventLoopGroup)
.workerEventLoopGroup(eventLoopGroup)
.intercept(new RequestAttributesInterceptor(grpcClientConnectionManager))
.addService(new RequestAttributesServiceImpl())
.build()
.start();
managedChannel = NettyChannelBuilder.forAddress(serverAddress)
.channelType(LocalChannel.class)
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
@AfterEach
void tearDown() {
managedChannel.shutdown();
server.shutdown();
}
@AfterAll
static void tearDownAfterAll() throws InterruptedException {
eventLoopGroup.shutdownGracefully().await();
assertEquals(Locale.LanguageRange.parse("en,ja"),
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAcceptableLanguages));
}
@Test
void getAcceptableLanguages() {
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
void getAvailableAcceptedLocales() throws Exception {
assertEquals(Collections.emptyList(),
callWithRequestAttributes(buildRequestAttributes(Collections.emptyList()),
RequestAttributesUtil::getAvailableAcceptedLocales));
assertTrue(getRequestAttributes().getAcceptableLanguagesList().isEmpty());
final List<Locale> availableAcceptedLocales =
callWithRequestAttributes(buildRequestAttributes(Locale.LanguageRange.parse("en,ja")),
RequestAttributesUtil::getAvailableAcceptedLocales);
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
assertFalse(availableAcceptedLocales.isEmpty());
assertEquals(List.of("en", "ja"), getRequestAttributes().getAcceptableLanguagesList());
availableAcceptedLocales.forEach(locale ->
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage())));
}
@Test
void getAvailableAcceptedLocales() {
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getAvailableAcceptedLocalesList().isEmpty());
when(grpcClientConnectionManager.getAcceptableLanguages(any()))
.thenReturn(Optional.of(Locale.LanguageRange.parse("en,ja")));
final GetRequestAttributesResponse response = getRequestAttributes();
assertFalse(response.getAvailableAcceptedLocalesList().isEmpty());
response.getAvailableAcceptedLocalesList().forEach(languageTag -> {
final Locale locale = Locale.forLanguageTag(languageTag);
assertTrue("en".equals(locale.getLanguage()) || "ja".equals(locale.getLanguage()));
});
void getRemoteAddress() throws Exception {
assertEquals(REMOTE_ADDRESS,
callWithRequestAttributes(new RequestAttributes(REMOTE_ADDRESS, null, null),
RequestAttributesUtil::getRemoteAddress));
}
@Test
void getRemoteAddress() {
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.empty());
void getUserAgent() throws Exception {
assertEquals(Optional.empty(),
callWithRequestAttributes(buildRequestAttributes((String) null),
RequestAttributesUtil::getUserAgent));
GrpcTestUtils.assertStatusException(Status.INTERNAL, this::getRequestAttributes);
final String remoteAddressString = "6.7.8.9";
when(grpcClientConnectionManager.getRemoteAddress(any()))
.thenReturn(Optional.of(InetAddresses.forString(remoteAddressString)));
assertEquals(remoteAddressString, getRequestAttributes().getRemoteAddress());
assertEquals(Optional.of("Signal-Desktop/1.2.3 Linux"),
callWithRequestAttributes(buildRequestAttributes("Signal-Desktop/1.2.3 Linux"),
RequestAttributesUtil::getUserAgent));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.empty());
assertFalse(getRequestAttributes().hasUserAgent());
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
when(grpcClientConnectionManager.getUserAgent(any()))
.thenReturn(Optional.of(userAgent));
final GetRequestAttributesResponse response = getRequestAttributes();
assertTrue(response.hasUserAgent());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
private static <V> V callWithRequestAttributes(final RequestAttributes requestAttributes, final Callable<V> callable) throws Exception {
return Context.current()
.withValue(RequestAttributesUtil.REQUEST_ATTRIBUTES_CONTEXT_KEY, requestAttributes)
.call(callable);
}
@Test
void getRawUserAgent() {
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.empty());
assertTrue(getRequestAttributes().getRawUserAgent().isBlank());
final String userAgentString = "Signal-Desktop/1.2.3 Linux";
when(grpcClientConnectionManager.getRawUserAgent(any()))
.thenReturn(Optional.of(userAgentString));
assertEquals(userAgentString, getRequestAttributes().getRawUserAgent());
private static RequestAttributes buildRequestAttributes(final String userAgent) {
return buildRequestAttributes(userAgent, Collections.emptyList());
}
private GetRequestAttributesResponse getRequestAttributes() {
return RequestAttributesGrpc.newBlockingStub(managedChannel)
.getRequestAttributes(GetRequestAttributesRequest.newBuilder().build());
private static RequestAttributes buildRequestAttributes(final List<Locale.LanguageRange> acceptLanguage) {
return buildRequestAttributes(null, acceptLanguage);
}
private static RequestAttributes buildRequestAttributes(@Nullable final String userAgent,
final List<Locale.LanguageRange> acceptLanguage) {
return new RequestAttributes(REMOTE_ADDRESS, userAgent, acceptLanguage);
}
}

View File

@@ -1,7 +1,11 @@
package org.whispersystems.textsecuregcm.grpc.net;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.net.InetAddresses;
import com.vdurmont.semver4j.Semver;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
@@ -12,6 +16,12 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
@@ -21,20 +31,9 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.grpc.ChannelNotFoundException;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import javax.annotation.Nullable;
import java.net.InetAddress;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
class GrpcClientConnectionManagerTest {
@@ -103,7 +102,7 @@ class GrpcClientConnectionManagerTest {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, maybeAuthenticatedDevice);
assertEquals(maybeAuthenticatedDevice,
grpcClientConnectionManager.getAuthenticatedDevice(localChannel.localAddress()));
grpcClientConnectionManager.getAuthenticatedDevice(remoteChannel));
}
private static List<Optional<AuthenticatedDevice>> getAuthenticatedDevice() {
@@ -114,170 +113,115 @@ class GrpcClientConnectionManagerTest {
}
@Test
void getAcceptableLanguages() {
void getRequestAttributes() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
assertThrows(IllegalStateException.class, () -> grpcClientConnectionManager.getRequestAttributes(remoteChannel));
final List<Locale.LanguageRange> acceptLanguageRanges = Locale.LanguageRange.parse("en,ja");
remoteChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).set(acceptLanguageRanges);
final RequestAttributes requestAttributes = new RequestAttributes(InetAddresses.forString("6.7.8.9"), null, null);
remoteChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).set(requestAttributes);
assertEquals(Optional.of(acceptLanguageRanges),
grpcClientConnectionManager.getAcceptableLanguages(localChannel.localAddress()));
assertEquals(requestAttributes, grpcClientConnectionManager.getRequestAttributes(remoteChannel));
}
@Test
void getRemoteAddress() {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
final InetAddress remoteAddress = InetAddresses.forString("6.7.8.9");
remoteChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).set(remoteAddress);
assertEquals(Optional.of(remoteAddress),
grpcClientConnectionManager.getRemoteAddress(localChannel.localAddress()));
}
@Test
void getUserAgent() throws UnrecognizedUserAgentException {
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(Optional.empty(),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
final UserAgent userAgent = UserAgentUtil.parseUserAgentString("Signal-Desktop/1.2.3 Linux");
remoteChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).set(userAgent);
assertEquals(Optional.of(userAgent),
grpcClientConnectionManager.getUserAgent(localChannel.localAddress()));
}
@Test
void closeConnection() throws InterruptedException {
void closeConnection() throws InterruptedException, ChannelNotFoundException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertTrue(remoteChannel.isOpen());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel),
grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleWebSocketHandshakeCompleteRemoteAddress() {
@ParameterizedTest
@MethodSource
void handleHandshakeCompleteRequestAttributes(final InetAddress preferredRemoteAddress,
final String userAgentHeader,
final String acceptLanguageHeader,
final RequestAttributes expectedRequestAttributes) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
GrpcClientConnectionManager.handleHandshakeComplete(embeddedChannel,
preferredRemoteAddress,
null,
null);
assertEquals(preferredRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteUserAgent(@Nullable final String userAgentHeader,
@Nullable final UserAgent expectedParsedUserAgent) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
userAgentHeader,
null);
assertEquals(userAgentHeader,
embeddedChannel.attr(GrpcClientConnectionManager.RAW_USER_AGENT_ATTRIBUTE_KEY).get());
assertEquals(expectedParsedUserAgent,
embeddedChannel.attr(GrpcClientConnectionManager.PARSED_USER_AGENT_ATTRIBUTE_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteUserAgent() {
return List.of(
// Recognized user-agent
Arguments.of("Signal-Desktop/1.2.3 Linux", new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux")),
// Unrecognized user-agent
Arguments.of("Not a valid user-agent string", null),
// Missing user-agent
Arguments.of(null, null)
);
}
@ParameterizedTest
@MethodSource
void handleWebSocketHandshakeCompleteAcceptLanguage(@Nullable final String acceptLanguageHeader,
@Nullable final List<Locale.LanguageRange> expectedLanguageRanges) {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
GrpcClientConnectionManager.handleWebSocketHandshakeComplete(embeddedChannel,
InetAddresses.forString("127.0.0.1"),
null,
acceptLanguageHeader);
assertEquals(expectedLanguageRanges,
embeddedChannel.attr(GrpcClientConnectionManager.ACCEPT_LANGUAGE_ATTRIBUTE_KEY).get());
assertEquals(expectedRequestAttributes,
embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY).get());
}
private static List<Arguments> handleWebSocketHandshakeCompleteAcceptLanguage() {
private static List<Arguments> handleHandshakeCompleteRequestAttributes() {
final InetAddress preferredRemoteAddress = InetAddresses.forString("192.168.1.1");
return List.of(
// Parseable list
Arguments.of("ja,en;q=0.4", Locale.LanguageRange.parse("ja,en;q=0.4")),
Arguments.argumentSet("Null User-Agent and Accept-Language headers",
preferredRemoteAddress, null, null,
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList())),
// Unparsable list
Arguments.of("This is not a valid language preference list", null),
Arguments.argumentSet("Recognized User-Agent and null Accept-Language header",
preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", null,
new RequestAttributes(preferredRemoteAddress, "Signal-Desktop/1.2.3 Linux", Collections.emptyList())),
// Missing list
Arguments.of(null, null)
Arguments.argumentSet("Unparsable User-Agent and null Accept-Language header",
preferredRemoteAddress, "Not a valid user-agent string", null,
new RequestAttributes(preferredRemoteAddress, "Not a valid user-agent string", Collections.emptyList())),
Arguments.argumentSet("Null User-Agent and parsable Accept-Language header",
preferredRemoteAddress, null, "ja,en;q=0.4",
new RequestAttributes(preferredRemoteAddress, null, Locale.LanguageRange.parse("ja,en;q=0.4"))),
Arguments.argumentSet("Null User-Agent and unparsable Accept-Language header",
preferredRemoteAddress, null, "This is not a valid language preference list",
new RequestAttributes(preferredRemoteAddress, null, Collections.emptyList()))
);
}
@Test
void handleConnectionEstablishedAuthenticated() throws InterruptedException {
void handleConnectionEstablishedAuthenticated() throws InterruptedException, ChannelNotFoundException {
final AuthenticatedDevice authenticatedDevice = new AuthenticatedDevice(UUID.randomUUID(), Device.PRIMARY_ID);
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.of(authenticatedDevice));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertEquals(List.of(remoteChannel), grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
assertNull(grpcClientConnectionManager.getRemoteChannelsByAuthenticatedDevice(authenticatedDevice));
}
@Test
void handleConnectionEstablishedAnonymous() throws InterruptedException {
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
void handleConnectionEstablishedAnonymous() throws InterruptedException, ChannelNotFoundException {
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
grpcClientConnectionManager.handleConnectionEstablished(localChannel, remoteChannel, Optional.empty());
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertEquals(remoteChannel, grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
remoteChannel.close().await();
assertNull(grpcClientConnectionManager.getRemoteChannelByLocalAddress(localChannel.localAddress()));
assertThrows(ChannelNotFoundException.class,
() -> grpcClientConnectionManager.getRemoteChannel(localChannel.localAddress()));
}
}

View File

@@ -523,10 +523,7 @@ class NoiseWebSocketTunnelServerIntegrationTest extends AbstractLeakDetectionTes
assertEquals(remoteAddress, response.getRemoteAddress());
assertEquals(List.of(acceptLanguage), response.getAcceptableLanguagesList());
assertEquals("DESKTOP", response.getUserAgent().getPlatform());
assertEquals("1.2.3", response.getUserAgent().getVersion());
assertEquals("Linux", response.getUserAgent().getAdditionalSpecifiers());
assertEquals(userAgent, response.getUserAgent());
} finally {
channel.shutdown();
}

View File

@@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.argumentSet;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
@@ -16,6 +17,7 @@ import io.netty.channel.local.LocalAddress;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.Attribute;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
@@ -31,6 +33,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
@@ -134,8 +137,13 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
embeddedChannel.setRemoteAddress(remoteAddress);
embeddedChannel.pipeline().fireUserEventTriggered(handshakeCompleteEvent);
assertEquals(expectedRemoteAddress,
embeddedChannel.attr(GrpcClientConnectionManager.REMOTE_ADDRESS_ATTRIBUTE_KEY).get());
Optional.ofNullable(embeddedChannel.attr(GrpcClientConnectionManager.REQUEST_ATTRIBUTES_KEY))
.map(Attribute::get)
.map(RequestAttributes::remoteAddress)
.orElse(null));
}
private static List<Arguments> getRemoteAddress() {
@@ -144,53 +152,53 @@ class WebsocketHandshakeCompleteHandlerTest extends AbstractLeakDetectionTest {
final InetAddress proxyAddress = InetAddresses.forString("4.3.2.1");
return List.of(
// Recognized proxy, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
argumentSet("Recognized proxy, single forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
clientAddress),
// Recognized proxy, multiple forwarded-for addresses
Arguments.of(new DefaultHttpHeaders()
argumentSet("Recognized proxy, multiple forwarded-for addresses",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress() + "," + proxyAddress.getHostAddress()),
remoteAddress,
proxyAddress),
// No recognized proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
argumentSet("No recognized proxy header, single forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
// No recognized proxy header, no forwarded-for address
Arguments.of(new DefaultHttpHeaders(),
argumentSet("No recognized proxy header, no forwarded-for address",
new DefaultHttpHeaders(),
remoteAddress,
remoteAddress.getAddress()),
// Incorrect proxy header, single forwarded-for address
Arguments.of(new DefaultHttpHeaders()
argumentSet("Incorrect proxy header, single forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET + "-incorrect")
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, clientAddress.getHostAddress()),
remoteAddress,
remoteAddress.getAddress()),
// Recognized proxy, no forwarded-for address
Arguments.of(new DefaultHttpHeaders()
argumentSet("Recognized proxy, no forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
remoteAddress,
remoteAddress.getAddress()),
// Recognized proxy, bogus forwarded-for address
Arguments.of(new DefaultHttpHeaders()
argumentSet("Recognized proxy, bogus forwarded-for address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET)
.add(WebsocketHandshakeCompleteHandler.FORWARDED_FOR_HEADER, "not a valid address"),
remoteAddress,
null),
// No forwarded-for address, non-InetSocketAddress remote address
Arguments.of(new DefaultHttpHeaders()
argumentSet("No forwarded-for address, non-InetSocketAddress remote address",
new DefaultHttpHeaders()
.add(WebsocketHandshakeCompleteHandler.RECOGNIZED_PROXY_SECRET_HEADER, RECOGNIZED_PROXY_SECRET),
new LocalAddress("local-address"),
null)