diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyController.java index 420bb1cf7..1bea965b4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyController.java @@ -66,11 +66,15 @@ public class CallQualitySurveyController { try { submitCallQualitySurveyRequest = SubmitCallQualitySurveyRequest.parseFrom(surveyResponse); } catch (final InvalidProtocolBufferException e) { - throw new WebApplicationException(422); + throw new WebApplicationException("Invalid protobuf entity", 422); } final String remoteAddress = (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - callQualitySurveyManager.submitCallQualitySurvey(submitCallQualitySurveyRequest, remoteAddress, userAgentString); + try { + callQualitySurveyManager.submitCallQualitySurvey(submitCallQualitySurveyRequest, remoteAddress, userAgentString); + } catch (final IllegalArgumentException e) { + throw new WebApplicationException(e.getMessage(), 422); + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcService.java index cb6ecfc4f..f52d56107 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcService.java @@ -5,6 +5,8 @@ package org.whispersystems.textsecuregcm.grpc; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; import org.signal.chat.calling.quality.SimpleCallQualityGrpc; import org.signal.chat.calling.quality.SubmitCallQualitySurveyRequest; import org.signal.chat.calling.quality.SubmitCallQualitySurveyResponse; @@ -32,9 +34,13 @@ public class CallQualitySurveyGrpcService extends SimpleCallQualityGrpc.CallQual rateLimiters.getSubmitCallQualitySurveyLimiter().validate(remoteAddress); - callQualitySurveyManager.submitCallQualitySurvey(request, - remoteAddress, - RequestAttributesUtil.getUserAgent().orElse(null)); + try { + callQualitySurveyManager.submitCallQualitySurvey(request, + remoteAddress, + RequestAttributesUtil.getUserAgent().orElse(null)); + } catch (final IllegalArgumentException e) { + throw Status.INVALID_ARGUMENT.withDescription(e.getMessage()).asRuntimeException(); + } return SubmitCallQualitySurveyResponse.getDefaultInstance(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManager.java index 5e1bba00c..827e969f1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManager.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.metrics; import com.google.cloud.pubsub.v1.PublisherInterface; +import com.google.common.annotations.VisibleForTesting; import com.google.pubsub.v1.PubsubMessage; import io.micrometer.core.instrument.Metrics; import java.time.Clock; @@ -147,4 +148,23 @@ public class CallQualitySurveyManager { .increment(); }); } + + @VisibleForTesting + static void validateRequest(final SubmitCallQualitySurveyRequest request) { + if (request.getStartTimestamp() == 0) { + throw new IllegalArgumentException("Start timestamp not specified"); + } + + if (request.getEndTimestamp() == 0) { + throw new IllegalArgumentException("End timestamp not specified"); + } + + if (StringUtils.isBlank(request.getCallType())) { + throw new IllegalArgumentException("Call type not specified"); + } + + if (StringUtils.isBlank(request.getCallEndReason())) { + throw new IllegalArgumentException("Call end reason not specified"); + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyControllerTest.java index 5980c3526..f3de06077 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallQualitySurveyControllerTest.java @@ -7,10 +7,12 @@ package org.whispersystems.textsecuregcm.controllers; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -22,6 +24,9 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.signal.chat.calling.quality.SubmitCallQualitySurveyRequest; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; @@ -29,6 +34,7 @@ import org.whispersystems.textsecuregcm.metrics.CallQualitySurveyManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; +import java.util.List; @ExtendWith(DropwizardExtensionsSupport.class) class CallQualitySurveyControllerTest { @@ -83,4 +89,21 @@ class CallQualitySurveyControllerTest { verify(CALL_QUALITY_SURVEY_MANAGER, never()).submitCallQualitySurvey(any(), any(), any()); } } + + @Test + void submitCallQualitySurveyInvalidArgument() { + final SubmitCallQualitySurveyRequest request = SubmitCallQualitySurveyRequest.getDefaultInstance(); + + doThrow(new IllegalArgumentException()) + .when(CALL_QUALITY_SURVEY_MANAGER).submitCallQualitySurvey(request, REMOTE_ADDRESS, USER_AGENT); + + try (final Response response = RESOURCE_EXTENSION.getJerseyTest() + .target("/v1/call_quality_survey") + .request() + .header("User-Agent", USER_AGENT) + .put(Entity.entity(request.toByteArray(), MediaType.APPLICATION_OCTET_STREAM_TYPE))) { + + assertEquals(422, response.getStatus()); + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcServiceTest.java index 4af14d07c..caf29bd20 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/CallQualitySurveyGrpcServiceTest.java @@ -13,6 +13,7 @@ import static org.mockito.Mockito.when; import com.google.common.net.InetAddresses; import java.time.Duration; +import io.grpc.Status; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -67,4 +68,16 @@ class CallQualitySurveyGrpcServiceTest extends SimpleBaseGrpcTest unauthenticatedServiceStub().submitCallQualitySurvey(SubmitCallQualitySurveyRequest.getDefaultInstance())); } + + @Test + void submitCallQualitySurveyInvalidArgument() { + final SubmitCallQualitySurveyRequest request = SubmitCallQualitySurveyRequest.getDefaultInstance(); + + doThrow(new IllegalArgumentException()) + .when(callQualitySurveyManager).submitCallQualitySurvey(request, REMOTE_ADDRESS, USER_AGENT); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, + () -> unauthenticatedServiceStub().submitCallQualitySurvey(request)); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManagerTest.java index 858aff442..ee5b6cac7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/CallQualitySurveyManagerTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -27,6 +28,11 @@ import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.platform.commons.util.StringUtils; import org.mockito.ArgumentCaptor; import org.signal.calling.survey.CallQualitySurveyResponsePubSubMessage; import org.signal.chat.calling.quality.SubmitCallQualitySurveyRequest; @@ -140,4 +146,36 @@ class CallQualitySurveyManagerTest { assertEquals(videoSendPacketLossFraction, callQualitySurveyResponsePubSubMessage.getVideoSendPacketLossFraction()); assertArrayEquals(telemetryBytes, callQualitySurveyResponsePubSubMessage.getCallTelemetry().toByteArray()); } + + @ParameterizedTest + @MethodSource + void validateRequest(final SubmitCallQualitySurveyRequest request, final boolean expectValid) { + final Executable validateRequest = () -> CallQualitySurveyManager.validateRequest(request); + + if (expectValid) { + assertDoesNotThrow(validateRequest); + } else { + final IllegalArgumentException illegalArgumentException = + assertThrows(IllegalArgumentException.class, validateRequest); + + assertTrue(StringUtils.isNotBlank(illegalArgumentException.getMessage())); + } + } + + private static List validateRequest() { + final SubmitCallQualitySurveyRequest validRequest = SubmitCallQualitySurveyRequest.newBuilder() + .setStartTimestamp(Instant.now().toEpochMilli()) + .setEndTimestamp(Instant.now().plusSeconds(60).toEpochMilli()) + .setCallType("test") + .setCallEndReason("test") + .build(); + + return List.of( + Arguments.argumentSet("Valid survey response", validRequest, true), + Arguments.argumentSet("No start timestamp", validRequest.toBuilder().clearStartTimestamp().build(), false), + Arguments.argumentSet("No end timestamp", validRequest.toBuilder().clearEndTimestamp().build(), false), + Arguments.argumentSet("No call type", validRequest.toBuilder().clearCallType().build(), false), + Arguments.argumentSet("No call end reason", validRequest.toBuilder().clearCallEndReason().build(), false) + ); + } }