Use enriched gRPC status errors

This commit is contained in:
Ravi Khadiwala
2025-12-22 11:56:35 -06:00
committed by ravi-signal
parent 77eaec0150
commit a1b1d051f5
29 changed files with 989 additions and 318 deletions

View File

@@ -53,7 +53,7 @@ class ProhibitAuthenticationInterceptorTest {
final StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
() -> client.echo(EchoRequest.getDefaultInstance()));
assertEquals(Status.Code.UNAUTHENTICATED, e.getStatus().getCode());
assertEquals(Status.Code.INVALID_ARGUMENT, e.getStatus().getCode());
}
@Test

View File

@@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.filters;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
@@ -17,12 +18,14 @@ import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders;
import com.google.common.net.InetAddresses;
import com.google.protobuf.ByteString;
import com.google.rpc.ErrorInfo;
import com.vdurmont.semver4j.Semver;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.protobuf.StatusProto;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
@@ -40,9 +43,9 @@ import org.signal.chat.rpc.EchoServiceGrpc;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl;
import org.whispersystems.textsecuregcm.grpc.GrpcExceptions;
import org.whispersystems.textsecuregcm.grpc.MockRequestAttributesInterceptor;
import org.whispersystems.textsecuregcm.grpc.RequestAttributes;
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@@ -153,7 +156,14 @@ class RemoteDeprecationFilterTest {
final StatusRuntimeException e = assertThrows(
StatusRuntimeException.class,
() -> client.echo(req));
assertEquals(StatusConstants.UPGRADE_NEEDED_STATUS.toString(), e.getStatus().toString());
final com.google.rpc.Status status = StatusProto.fromThrowable(e);
final ErrorInfo errorInfo = assertDoesNotThrow(() -> status.getDetailsList().stream()
.filter(any -> any.is(ErrorInfo.class)).findFirst()
.orElseThrow(() -> new AssertionError("No error info found"))
.unpack(ErrorInfo.class));
assertEquals(GrpcExceptions.DOMAIN, errorInfo.getDomain());
assertEquals(io.grpc.Status.Code.INVALID_ARGUMENT.value(), status.getCode());
assertEquals("UPGRADE_REQUIRED", errorInfo.getReason());
} else {
assertEquals("cluck cluck, i'm a parrot", client.echo(req).getPayload().toStringUtf8());
}

View File

@@ -29,8 +29,11 @@ import org.mockito.Mock;
import org.signal.chat.account.AccountsAnonymousGrpc;
import org.signal.chat.account.CheckAccountExistenceRequest;
import org.signal.chat.account.LookupUsernameHashRequest;
import org.signal.chat.account.LookupUsernameHashResponse;
import org.signal.chat.account.LookupUsernameLinkRequest;
import org.signal.chat.account.LookupUsernameLinkResponse;
import org.signal.chat.common.IdentityType;
import org.signal.chat.errors.NotFound;
import org.signal.chat.common.ServiceIdentifier;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
@@ -151,8 +154,8 @@ class AccountsAnonymousGrpcServiceTest extends
.getServiceIdentifier());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.NOT_FOUND,
() -> unauthenticatedServiceStub().lookupUsernameHash(LookupUsernameHashRequest.newBuilder()
assertEquals(LookupUsernameHashResponse.newBuilder().setNotFound(NotFound.getDefaultInstance()).build(),
unauthenticatedServiceStub().lookupUsernameHash(LookupUsernameHashRequest.newBuilder()
.setUsernameHash(ByteString.copyFrom(new byte[AccountController.USERNAME_HASH_LENGTH]))
.build()));
}
@@ -217,15 +220,16 @@ class AccountsAnonymousGrpcServiceTest extends
when(account.getEncryptedUsername()).thenReturn(Optional.empty());
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.NOT_FOUND,
() -> unauthenticatedServiceStub().lookupUsernameLink(LookupUsernameLinkRequest.newBuilder()
final LookupUsernameLinkResponse notFoundResponse = LookupUsernameLinkResponse.newBuilder()
.setNotFound(NotFound.getDefaultInstance())
.build();
assertEquals(notFoundResponse,
unauthenticatedServiceStub().lookupUsernameLink(LookupUsernameLinkRequest.newBuilder()
.setUsernameLinkHandle(UUIDUtil.toByteString(linkHandle))
.build()));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.NOT_FOUND,
() -> unauthenticatedServiceStub().lookupUsernameLink(LookupUsernameLinkRequest.newBuilder()
assertEquals(notFoundResponse,
unauthenticatedServiceStub().lookupUsernameLink(LookupUsernameLinkRequest.newBuilder()
.setUsernameLinkHandle(UUIDUtil.toByteString(UUID.randomUUID()))
.build()));
}

View File

@@ -47,8 +47,6 @@ import org.signal.chat.account.DeleteUsernameHashRequest;
import org.signal.chat.account.DeleteUsernameLinkRequest;
import org.signal.chat.account.GetAccountIdentityRequest;
import org.signal.chat.account.GetAccountIdentityResponse;
import org.signal.chat.account.ReserveUsernameHashError;
import org.signal.chat.account.ReserveUsernameHashErrorType;
import org.signal.chat.account.ReserveUsernameHashRequest;
import org.signal.chat.account.ReserveUsernameHashResponse;
import org.signal.chat.account.SetDiscoverableByPhoneNumberRequest;
@@ -57,7 +55,9 @@ import org.signal.chat.account.SetRegistrationLockResponse;
import org.signal.chat.account.SetRegistrationRecoveryPasswordRequest;
import org.signal.chat.account.SetUsernameLinkRequest;
import org.signal.chat.account.SetUsernameLinkResponse;
import org.signal.chat.account.UsernameNotAvailable;
import org.signal.chat.common.AccountIdentifiers;
import org.signal.chat.errors.FailedPrecondition;
import org.signal.libsignal.usernames.BaseUsernameException;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
@@ -173,7 +173,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, "BAD_AUTHENTICATION",
() -> authenticatedServiceStub().deleteAccount(DeleteAccountRequest.newBuilder().build()));
verify(accountsManager, never()).delete(any(), any());
@@ -217,7 +217,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, "BAD_AUTHENTICATION",
() -> authenticatedServiceStub().setRegistrationLock(SetRegistrationLockRequest.newBuilder()
.build()));
@@ -242,7 +242,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT,
() -> authenticatedServiceStub().clearRegistrationLock(ClearRegistrationLockRequest.newBuilder().build()));
verify(accountsManager, never()).updateAsync(any(), any());
@@ -288,9 +288,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
.thenReturn(CompletableFuture.failedFuture(new UsernameHashNotAvailableException()));
final ReserveUsernameHashResponse expectedResponse = ReserveUsernameHashResponse.newBuilder()
.setError(ReserveUsernameHashError.newBuilder()
.setErrorType(ReserveUsernameHashErrorType.RESERVE_USERNAME_HASH_ERROR_TYPE_NO_HASHES_AVAILABLE)
.build())
.setUsernameNotAvailable(UsernameNotAvailable.getDefaultInstance())
.build();
assertEquals(expectedResponse,
@@ -379,8 +377,10 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
});
final ConfirmUsernameHashResponse expectedResponse = ConfirmUsernameHashResponse.newBuilder()
.setUsernameHash(ByteString.copyFrom(usernameHash))
.setUsernameLinkHandle(UUIDUtil.toByteString(linkHandle))
.setConfirmedUsernameHash(ConfirmUsernameHashResponse.ConfirmedUsernameHash.newBuilder()
.setUsernameHash(ByteString.copyFrom(usernameHash))
.setUsernameLinkHandle(UUIDUtil.toByteString(linkHandle))
.build())
.build();
assertEquals(expectedResponse,
@@ -393,7 +393,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@ParameterizedTest
@MethodSource
void confirmUsernameHashConfirmationException(final Exception confirmationException, final Status expectedStatus) {
void confirmUsernameHashConfirmationException(final Exception confirmationException, final ConfirmUsernameHashResponse expectedResponse) {
final byte[] usernameHash = TestRandomUtil.nextBytes(AccountController.USERNAME_HASH_LENGTH);
final byte[] usernameCiphertext = TestRandomUtil.nextBytes(32);
@@ -408,19 +408,26 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
when(accountsManager.confirmReservedUsernameHash(any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(confirmationException));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(expectedStatus,
() -> authenticatedServiceStub().confirmUsernameHash(ConfirmUsernameHashRequest.newBuilder()
final ConfirmUsernameHashResponse actualResponse = authenticatedServiceStub()
.confirmUsernameHash(ConfirmUsernameHashRequest.newBuilder()
.setUsernameHash(ByteString.copyFrom(usernameHash))
.setUsernameCiphertext(ByteString.copyFrom(usernameCiphertext))
.setZkProof(ByteString.copyFrom(zkProof))
.build()));
.build());
assertEquals(expectedResponse, actualResponse);
}
private static Stream<Arguments> confirmUsernameHashConfirmationException() {
return Stream.of(
Arguments.of(new UsernameHashNotAvailableException(), Status.NOT_FOUND),
Arguments.of(new UsernameReservationNotFoundException(), Status.FAILED_PRECONDITION)
Arguments.of( new UsernameHashNotAvailableException(),
ConfirmUsernameHashResponse.newBuilder()
.setUsernameNotAvailable(UsernameNotAvailable.getDefaultInstance())
.build()),
Arguments.of(new UsernameReservationNotFoundException(),
ConfirmUsernameHashResponse.newBuilder()
.setReservationNotFound(FailedPrecondition.getDefaultInstance())
.build())
);
}
@@ -546,9 +553,11 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
final byte[] usernameCiphertext = TestRandomUtil.nextBytes(EncryptedUsername.MAX_SIZE);
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.FAILED_PRECONDITION,
() -> authenticatedServiceStub().setUsernameLink(SetUsernameLinkRequest.newBuilder()
assertEquals(
SetUsernameLinkResponse.newBuilder()
.setNoUsernameSet(FailedPrecondition.getDefaultInstance())
.build(),
authenticatedServiceStub().setUsernameLink(SetUsernameLinkRequest.newBuilder()
.setUsernameCiphertext(ByteString.copyFrom(usernameCiphertext))
.build()));
}

View File

@@ -17,6 +17,7 @@ import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Clock;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.Iterator;
@@ -295,12 +296,10 @@ class BackupsAnonymousGrpcServiceTest extends
assertThat(uploadForm.getSignedUploadLocation()).isEqualTo("example.org");
// rate limit
Duration duration = Duration.ofSeconds(10);
when(backupManager.createTemporaryAttachmentUploadDescriptor(any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> unauthenticatedServiceStub().getUploadForm(request))
.extracting(StatusRuntimeException::getStatus)
.isEqualTo(Status.RESOURCE_EXHAUSTED);
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(duration)));
GrpcTestUtils.assertRateLimitExceeded(duration, () -> unauthenticatedServiceStub().getUploadForm(request));
}
static Stream<Arguments> messagesUploadForm() {

View File

@@ -0,0 +1,157 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.Any;
import com.google.rpc.ErrorInfo;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.protobuf.StatusProto;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.ReactorEchoServiceGrpc;
import org.signal.chat.rpc.SimpleEchoServiceGrpc;
import reactor.core.publisher.Mono;
class ErrorMappingInterceptorTest {
private Server server;
private ManagedChannel channel;
@BeforeEach
void setUp() throws Exception {
channel = InProcessChannelBuilder.forName("ErrorMappingInterceptorTest")
.directExecutor()
.build();
}
@AfterEach
void tearDown() throws Exception {
server.shutdownNow();
channel.shutdownNow();
server.awaitTermination(1, TimeUnit.SECONDS);
channel.awaitTermination(1, TimeUnit.SECONDS);
}
@Test
public void includeDetailsSimpleGrpc() throws Exception {
final StatusRuntimeException e = StatusProto.toStatusRuntimeException(com.google.rpc.Status.newBuilder()
.setCode(Status.Code.INVALID_ARGUMENT.value())
.addDetails(Any.pack(ErrorInfo.newBuilder()
.setDomain("test")
.setReason("TEST")
.build()))
.build());
server = InProcessServerBuilder.forName("ErrorMappingInterceptorTest")
.directExecutor()
.addService(new SimpleEchoServiceErrorImpl(e))
.intercept(new ErrorMappingInterceptor())
.build()
.start();
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, "TEST", () ->
client.echo(EchoRequest.getDefaultInstance()));
}
@Test
public void includeDetailsReactiveGrpc() throws Exception {
final StatusRuntimeException e = StatusProto.toStatusRuntimeException(com.google.rpc.Status.newBuilder()
.setCode(Status.Code.INVALID_ARGUMENT.value())
.addDetails(Any.pack(ErrorInfo.newBuilder()
.setDomain("test")
.setReason("TEST")
.build()))
.build());
server = InProcessServerBuilder.forName("ErrorMappingInterceptorTest")
.directExecutor()
.addService(new ReactorEchoServiceErrorImpl(e))
.intercept(new ErrorMappingInterceptor())
.build()
.start();
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, "TEST", () ->
client.echo(EchoRequest.getDefaultInstance()));
}
@Test
public void mapIOExceptionsReactive() throws Exception {
server = InProcessServerBuilder.forName("ErrorMappingInterceptorTest")
.directExecutor()
.addService(new ReactorEchoServiceErrorImpl(new IOException("test")))
.intercept(new ErrorMappingInterceptor())
.build()
.start();
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, "UNAVAILABLE", () ->
client.echo(EchoRequest.getDefaultInstance()));
}
@Test
public void mapIOExceptionsSimple() throws Exception {
server = InProcessServerBuilder.forName("ErrorMappingInterceptorTest")
.directExecutor()
.addService(new SimpleEchoServiceErrorImpl(new UncheckedIOException(new IOException("test"))))
.intercept(new ErrorMappingInterceptor())
.build()
.start();
final EchoServiceGrpc.EchoServiceBlockingStub client = EchoServiceGrpc.newBlockingStub(channel);
GrpcTestUtils.assertStatusException(Status.UNAVAILABLE, "UNAVAILABLE", () ->
client.echo(EchoRequest.getDefaultInstance()));
}
static class ReactorEchoServiceErrorImpl extends ReactorEchoServiceGrpc.EchoServiceImplBase {
private final Exception exception;
ReactorEchoServiceErrorImpl(final Exception exception) {
this.exception = exception;
}
@Override
public Mono<EchoResponse> echo(final EchoRequest echoRequest) {
return Mono.error(exception);
}
@Override
public Throwable onErrorMap(Throwable throwable) {
return new IllegalArgumentException(throwable);
}
}
static class SimpleEchoServiceErrorImpl extends SimpleEchoServiceGrpc.EchoServiceImplBase {
private final RuntimeException exception;
SimpleEchoServiceErrorImpl(final RuntimeException exception) {
this.exception = exception;
}
@Override
public EchoResponse echo(final EchoRequest echoRequest) {
throw exception;
}
}
}

View File

@@ -122,14 +122,6 @@ public class ExternalServiceCredentialsGrpcServiceTest
);
}
@Test
public void testUnauthenticatedCall() throws Exception {
assertStatusUnauthenticated(() -> unauthenticatedServiceStub().getExternalServiceCredentials(
GetExternalServiceCredentialsRequest.newBuilder()
.setExternalService(ExternalServiceType.EXTERNAL_SERVICE_TYPE_DIRECTORY)
.build()));
}
/**
* `ExternalServiceDefinitions` enum is supposed to have entries for all values in `ExternalServiceType`,
* except for the `EXTERNAL_SERVICE_TYPE_UNSPECIFIED` and `UNRECOGNIZED`.

View File

@@ -5,16 +5,24 @@
package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.verifyNoInteractions;
import com.google.protobuf.Any;
import com.google.protobuf.Message;
import com.google.rpc.ErrorInfo;
import com.google.rpc.RetryInfo;
import io.grpc.BindableService;
import io.grpc.ServerInterceptors;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import io.grpc.protobuf.StatusProto;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.function.Executable;
import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
@@ -51,6 +59,12 @@ public final class GrpcTestUtils {
assertEquals(expected.getCode(), exception.getStatus().getCode());
}
public static void assertStatusException(final Status expected, final String expectedReason, final Executable serviceCall) {
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
assertEquals(expected.getCode(), exception.getStatus().getCode());
assertEquals(expectedReason, extractErrorInfo(exception).getReason());
}
public static void assertStatusInvalidArgument(final Executable serviceCall) {
assertStatusException(Status.INVALID_ARGUMENT, serviceCall);
}
@@ -68,11 +82,31 @@ public final class GrpcTestUtils {
final Executable serviceCall,
final Object... mocksToCheckForNoInteraction) {
final StatusRuntimeException exception = Assertions.assertThrows(StatusRuntimeException.class, serviceCall);
assertEquals(Status.RESOURCE_EXHAUSTED, exception.getStatus());
assertEquals(Status.RESOURCE_EXHAUSTED.getCode(), exception.getStatus().getCode());
assertNotNull(exception.getTrailers());
assertEquals(expectedRetryAfter, exception.getTrailers().get(RateLimitExceededException.RETRY_AFTER_DURATION_KEY));
final ErrorInfo errorInfo = extractErrorInfo(exception);
final RetryInfo retryInfo = extractDetail(RetryInfo.class, exception);
final Duration actual = Duration.ofSeconds(retryInfo.getRetryDelay().getSeconds(), retryInfo.getRetryDelay().getNanos());
assertEquals(errorInfo.getDomain(), GrpcExceptions.DOMAIN);
assertEquals(errorInfo.getReason(), "RESOURCE_EXHAUSTED");
assertEquals(expectedRetryAfter, actual);
for (final Object mock: mocksToCheckForNoInteraction) {
verifyNoInteractions(mock);
}
}
public static ErrorInfo extractErrorInfo(final StatusRuntimeException exception) {
return extractDetail(ErrorInfo.class, exception);
}
public static <T extends Message> T extractDetail(final Class<T> detailCls, final StatusRuntimeException exception) {
final com.google.rpc.Status status = StatusProto.fromThrowable(exception);
return assertDoesNotThrow(() -> status.getDetailsList().stream()
.filter(any -> any.is(detailCls)).findFirst()
.orElseThrow(() -> new AssertionError("No error info found"))
.unpack(detailCls));
}
}

View File

@@ -10,12 +10,17 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.net.InetAddresses;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import com.google.rpc.ErrorInfo;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.protobuf.StatusProto;
import io.grpc.stub.BlockingClientCall;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Meter;
@@ -24,34 +29,48 @@ import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import java.util.List;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.chat.rpc.EchoRequest;
import org.signal.chat.rpc.EchoResponse;
import org.signal.chat.rpc.EchoServiceGrpc;
import org.signal.chat.rpc.SimpleTagTestServiceGrpc;
import org.signal.chat.rpc.TagResponse;
import org.signal.chat.rpc.TagTestServiceGrpc;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import reactor.adapter.JdkFlowAdapter;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
public class MetricServerInterceptorTest {
private static String USER_AGENT = "Signal-Android/4.53.7 (Android 8.1; libsignal)";
private static final String USER_AGENT = "Signal-Android/4.53.7 (Android 8.1; libsignal)";
private Server server;
private ManagedChannel channel;
private SimpleMeterRegistry simpleMeterRegistry;
private ClientReleaseManager clientReleaseManager;
private Supplier<TagResponse> tagResponseSupplier;
@BeforeEach
void setUp() throws Exception {
simpleMeterRegistry = new SimpleMeterRegistry();
clientReleaseManager = mock(ClientReleaseManager.class);
tagResponseSupplier = mock(Supplier.class);
final MockRequestAttributesInterceptor mockRequestAttributesInterceptor = new MockRequestAttributesInterceptor();
mockRequestAttributesInterceptor.setRequestAttributes(
new RequestAttributes(InetAddresses.forString("127.0.0.1"), USER_AGENT, null));
@@ -59,9 +78,9 @@ public class MetricServerInterceptorTest {
server = InProcessServerBuilder.forName("MetricServerInterceptorTest")
.directExecutor()
.addService(new EchoServiceImpl())
.addService(new TagTestServiceImpl(tagResponseSupplier))
.intercept(new MetricServerInterceptor(simpleMeterRegistry, clientReleaseManager))
.intercept(mockRequestAttributesInterceptor)
.intercept(mockRequestAttributesInterceptor)
.build()
.start();
@@ -111,7 +130,7 @@ public class MetricServerInterceptorTest {
}
@Test
void streaming() throws StatusException, InterruptedException, TimeoutException {
void streaming() throws StatusException, InterruptedException {
final EchoServiceGrpc.EchoServiceBlockingV2Stub client = EchoServiceGrpc.newBlockingV2Stub(channel);
final BlockingClientCall<EchoRequest, EchoResponse> echoStream = client.echoStream();
echoStream.write(EchoRequest.newBuilder().setPayload(ByteString.copyFromUtf8("1")).build());
@@ -152,6 +171,99 @@ public class MetricServerInterceptorTest {
assertThat(expectedClientVersion).isEqualTo(actualClientVersion);
}
static Stream<Arguments> testUnaryOkResponseReason() {
return Stream.of(
Arguments.argumentSet("Default reason", TagResponse.newBuilder().build(), "success"),
Arguments.argumentSet("No reason", TagResponse.newBuilder().setNoReason(true).build(), "success"),
Arguments.argumentSet("Explicitly set reason", TagResponse.newBuilder().setReason1(true).build(), "reason_1"),
Arguments.argumentSet("Nested reason", TagResponse.newBuilder().setNestedReason(TagResponse.NestedReason.newBuilder().setReason(true)).build(), "nested_reason"));
}
@ParameterizedTest
@MethodSource
void testUnaryOkResponseReason(TagResponse response, String expectedReason) throws InterruptedException {
final TagTestServiceGrpc.TagTestServiceBlockingStub tagTestServiceBlockingStub =
TagTestServiceGrpc.newBlockingStub(channel);
when(tagResponseSupplier.get()).thenReturn(response);
tagTestServiceBlockingStub.tagEndpoint(Empty.getDefaultInstance());
final Counter rpcCount = find(Counter.class, MetricServerInterceptor.RPC_COUNTER_NAME);
assertThat(rpcCount.count()).isCloseTo(1.0, offset(0.01));
assertThat(rpcCount.getId().getTag("statusCode")).isEqualTo("OK");
assertThat(rpcCount.getId().getTag("reason")).isEqualTo(expectedReason);
}
@Test
public void testConflictingReasons() {
final TagTestServiceGrpc.TagTestServiceBlockingStub tagTestServiceBlockingStub =
TagTestServiceGrpc.newBlockingStub(channel);
when(tagResponseSupplier.get())
.thenReturn(TagResponse.newBuilder().setReason1(true).setConflictingReason(true).build());
tagTestServiceBlockingStub.tagEndpoint(Empty.getDefaultInstance());
// We make no promises if proto fields that have reason tags are present on a message, but this tests for the sane
// behavior that at least one of these tags makes it into the metric.
assertThat(find(Counter.class, MetricServerInterceptor.RPC_COUNTER_NAME).getId().getTag("reason"))
.isIn("duplicate_reason", "reason_1");
}
@CartesianTest
public void testStatusErrorResponseReason(
@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = {"OK"}) Status.Code statusCode,
@CartesianTest.Values(strings = {"test", "", "null"}) String reasonParam) {
final String reason, expectedReasonTag;
if (reasonParam.equals("null")) {
reason = null;
expectedReasonTag = MetricServerInterceptor.DEFAULT_ERROR_REASON;
} else {
reason = reasonParam;
expectedReasonTag = reasonParam;
}
final TagTestServiceGrpc.TagTestServiceBlockingStub tagTestServiceBlockingStub =
TagTestServiceGrpc.newBlockingStub(channel);
final com.google.rpc.Status.Builder builder = com.google.rpc.Status.newBuilder()
.setCode(statusCode.value())
.setMessage("test");
if (reason != null) {
builder.addDetails(Any.pack(ErrorInfo.newBuilder()
.setDomain("domain")
.setReason(reason)
.build()));
}
when(tagResponseSupplier.get()).thenThrow(StatusProto.toStatusRuntimeException(builder.build()));
GrpcTestUtils.assertStatusException(statusCode.toStatus(),
() -> tagTestServiceBlockingStub.tagEndpoint(Empty.getDefaultInstance()));
final Counter rpcCount = find(Counter.class, MetricServerInterceptor.RPC_COUNTER_NAME);
assertThat(rpcCount.count()).isCloseTo(1.0, offset(0.01));
assertThat(rpcCount.getId().getTag("statusCode")).isEqualTo(statusCode.name());
assertThat(rpcCount.getId().getTag("reason")).isEqualTo(expectedReasonTag);
}
@Test
public void testStreamingResponseReason() {
final TagTestServiceGrpc.TagTestServiceBlockingStub tagTestServiceBlockingStub =
TagTestServiceGrpc.newBlockingStub(channel);
when(tagResponseSupplier.get())
.thenReturn(TagResponse.newBuilder().setReason1(true).build())
.thenReturn(TagResponse.newBuilder().setNoReason(true).build())
.thenReturn(null);
tagTestServiceBlockingStub.streamingTagEndpoint(Empty.getDefaultInstance()).forEachRemaining(_ -> {});
final Counter messageCounter = find(Counter.class, MetricServerInterceptor.RESPONSE_COUNTER_NAME);
assertThat(messageCounter.count()).isCloseTo(2.0, offset(0.01));
final Counter rpcCount = find(Counter.class, MetricServerInterceptor.RPC_COUNTER_NAME);
assertThat(rpcCount.count()).isCloseTo(1.0, offset(0.01));
assertThat(rpcCount.getId().getTag("statusCode")).isEqualTo("OK");
assertThat(rpcCount.getId().getTag("reason")).isEqualTo(MetricServerInterceptor.DEFAULT_SUCCESS_REASON);
}
private <T extends Meter> T find(Class<T> cls, final String name) {
final Meter meter = simpleMeterRegistry.getMeters().stream()
.filter(m -> m.getId().getName().equals(name))
@@ -162,4 +274,32 @@ public class MetricServerInterceptorTest {
}
throw new IllegalArgumentException("Meter " + name + " should be an instance of " + cls);
}
class TagTestServiceImpl extends SimpleTagTestServiceGrpc.TagTestServiceImplBase {
private Supplier<TagResponse> tagResponseSupplier;
TagTestServiceImpl(Supplier<TagResponse> tagResponseSupplier) {
this.tagResponseSupplier = tagResponseSupplier;
}
@Override
public TagResponse tagEndpoint(final Empty request) {
return tagResponseSupplier.get();
}
@Override
public Flow.Publisher<TagResponse> streamingTagEndpoint(com.google.protobuf.Empty request) {
return JdkFlowAdapter.publisherToFlowPublisher(Flux.<TagResponse>create(sink -> {
while (!sink.isCancelled()) {
TagResponse item = tagResponseSupplier.get();
if (item == null) {
sink.complete();
break;
}
sink.next(item);
}
})
.subscribeOn(Schedulers.boundedElastic()));
}
}
}

View File

@@ -63,10 +63,4 @@ class PaymentsGrpcServiceTest extends SimpleBaseGrpcTest<PaymentsGrpcService, Pa
assertStatusException(Status.UNAVAILABLE, () -> authenticatedServiceStub().getCurrencyConversions(
GetCurrencyConversionsRequest.newBuilder().build()));
}
@Test
public void testUnauthenticated() throws Exception {
assertStatusException(Status.UNAUTHENTICATED, () -> unauthenticatedServiceStub().getCurrencyConversions(
GetCurrencyConversionsRequest.newBuilder().build()));
}
}