Check story rate limits in parallel

This commit is contained in:
Jon Chambers
2023-12-01 16:01:14 -05:00
committed by Jon Chambers
parent e9708b9259
commit 417d99a17e
2 changed files with 55 additions and 3 deletions

View File

@@ -43,6 +43,7 @@ import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
@@ -243,6 +244,8 @@ class MessageControllerTest {
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
when(rateLimiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null));
}
private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId,
@@ -1148,6 +1151,7 @@ class MessageControllerTest {
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages());
}
@SuppressWarnings("unused")
private static ArgumentSets testMultiRecipientMessageNoPni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
@@ -1449,7 +1453,7 @@ class MessageControllerTest {
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2)
throws NotPushRegisteredException, InterruptedException {
throws NotPushRegisteredException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]),
@@ -1490,6 +1494,37 @@ class MessageControllerTest {
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
}
@Test
void sendMultiRecipientMessageStoryRateLimited() {
final List<Recipient> recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", true)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class)))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true)));
try (final Response response = invocationBuilder.put(entity)) {
assertEquals(413, response.getStatus());
}
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());