multisend cleanup

This commit is contained in:
Jonathan Klabunde Tomer
2023-12-07 12:23:02 -08:00
committed by GitHub
parent 1fb88271e5
commit 4efda89358
3 changed files with 385 additions and 341 deletions

View File

@@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.collection.IsEmptyCollection.empty;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -29,6 +30,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet;
@@ -42,21 +44,24 @@ import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
@@ -73,8 +78,11 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsSources;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
@@ -92,8 +100,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
@@ -139,6 +145,7 @@ class MessageControllerTest {
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
@@ -149,6 +156,11 @@ class MessageControllerTest {
private static final int MULTI_DEVICE_REG_ID1 = 222;
private static final int MULTI_DEVICE_REG_ID2 = 333;
private static final int MULTI_DEVICE_REG_ID3 = 444;
private static final int MULTI_DEVICE_PNI_REG_ID1 = 2222;
private static final int MULTI_DEVICE_PNI_REG_ID2 = 3333;
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
private static final UUID NONEXISTENT_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
@@ -192,13 +204,13 @@ class MessageControllerTest {
final List<Device> singleDeviceList = List.of(
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
);
final List<Device> multiDeviceList = List.of(
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
);
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
@@ -211,6 +223,8 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
mock(DynamicInboundMessageByteLimitConfiguration.class);
@@ -942,25 +956,21 @@ class MessageControllerTest {
);
}
private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) {
long x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
long b = x & 0x7f;
x = x >>> 7;
if (x != 0) b |= 0x80;
bb.put((byte)b);
} while (x != 0);
private record Recipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
}
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) {
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r,
final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) {
bb.put(r.uuid().toFixedWidthByteArray());
} else {
bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
}
writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes)
bb.put(r.deviceId()); // device id (1 byte)
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
@@ -973,8 +983,15 @@ class MessageControllerTest {
// first write the header
bb.put(explicitIdentifiers
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
bb.put((byte)recipients.size()); // count varint
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
// count varint
int nRecip = recipients.size();
while (nRecip > 127) {
bb.put((byte) (nRecip & 0x7F | 0x80));
nRecip = nRecip >> 7;
}
bb.put((byte)(nRecip & 0x7F));
Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) {
@@ -988,23 +1005,65 @@ class MessageControllerTest {
return new ByteArrayInputStream(buffer, 0, bb.position());
}
@ParameterizedTest
@MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception {
@Test
void testManyRecipientMessage() throws Exception {
final int nRecipients = 999;
final int devicesPerRecipient = 5;
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final List<Recipient> recipients = new ArrayList<>();
final List<Recipient> recipients;
if (recipientUUID == MULTI_DEVICE_UUID) {
recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
);
} else {
recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
for (int i = 0; i < nRecipients; i++) {
final List<Device> devices =
IntStream.range(1, devicesPerRecipient + 1)
.mapToObj(
d -> generateTestDevice(
(byte) d, 100 + d, 10 * d, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(),
System.currentTimeMillis()))
.collect(Collectors.toList());
final UUID aci = new UUID(0L, (long) i);
final UUID pni = new UUID(1L, (long) i);
final String e164 = String.format("+1408555%04d", i);
final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci))).thenReturn(Optional.of(account));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni))).thenReturn(Optional.of(account));
devices.forEach(d -> recipients.add(new Recipient(new AciServiceIdentifier(aci), d.getId(), d.getRegistrationId(), new byte[48])));
}
byte[] buffer = new byte[1048576];
InputStream stream = initializeMultiPayload(recipients, buffer, true);
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
final Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.put(entity);
assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200)));
verify(messageSender, times(nRecipients * devicesPerRecipient)).sendMessage(any(), any(), any(), eq(true));
}
// see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations
private void testMultiRecipientMessage(
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
boolean authorize,
boolean isStory,
boolean urgent,
boolean explicitIdentifier,
int expectedStatus,
int expectedMessagesSent) throws Exception {
final List<Recipient> recipients = new ArrayList<>();
destinations.forEach(
(serviceIdentifier, deviceToRegistrationId) ->
deviceToRegistrationId.forEach(
(deviceId, registrationId) ->
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier);
// set up the entity to use in our PUT request
@@ -1023,124 +1082,160 @@ class MessageControllerTest {
// add access header if needed
if (authorize) {
String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES);
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
}
// make the PUT request
Response response = bldr.put(entity);
if (authorize) {
ArgumentCaptor<Envelope> envelopeArgumentCaptor = ArgumentCaptor.forClass(Envelope.class);
verify(messageSender, atLeastOnce()).sendMessage(any(), any(), envelopeArgumentCaptor.capture(), anyBoolean());
assertEquals(urgent, envelopeArgumentCaptor.getValue().getUrgent());
}
// We have a 2x2x2 grid of possible situations based on:
// - recipient enabled stories?
// - sender is authorized?
// - message is a story?
//
// (urgent is not included in the grid because it has no effect
// on any of the other settings.)
if (recipientUUID == MULTI_DEVICE_UUID) {
// This is the case where the recipient has enabled stories.
if(isStory) {
// We are sending a story, so we ignore access checks and expect this
// to go out to both the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// We are not sending a story, so we need to do access checks.
if (authorize) {
// When authorized we send a message to the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// When forbidden, we return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
} else {
// This is the case where the recipient has not enabled stories.
if (isStory) {
// We are sending a story, so we ignore access checks.
// this recipient has one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// We are not sending a story so check access.
if (authorize) {
// If allowed, send a message to the recipient's one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// If forbidden, return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender,
exactly(expectedMessagesSent))
.sendMessage(
any(),
any(),
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()),
eq(true));
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
}
}
// Arguments here are: recipient-UUID, is-authorized?, is-story?
private static Stream<Arguments> testMultiRecipientMessage() {
return Stream.of(
Arguments.of(MULTI_DEVICE_UUID, false, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true)
);
@SafeVarargs
private static <K, V> Map<K, V> submap(Map<K, V> map, K... keys) {
return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get));
}
@Test
void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception {
UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty());
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() {
return
Map.of(
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
new PniServiceIdentifier(MULTI_DEVICE_PNI),
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
);
}
final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1,
new byte[48]),
new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48]));
private record MultiRecipientMessageTestCase(
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
boolean authenticated,
boolean story,
int expectedStatus,
int expectedSentMessages) {
}
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1700000000000L)
.queryParam("story", true)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true),
MultiRecipientMessageProvider.MEDIA_TYPE));
@CartesianTest
@CartesianTest.MethodFactory("testMultiRecipientMessageNoPni")
void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception {
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages());
}
checkGoodMultiRecipientResponse(response, 1);
private static ArgumentSets testMultiRecipientMessageNoPni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
submap(
targets,
new AciServiceIdentifier(SINGLE_DEVICE_UUID),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new AciServiceIdentifier(NONEXISTENT_UUID));
final boolean auth = true;
final boolean unauth = false;
final boolean story = true;
final boolean notStory = false;
return ArgumentSets
.argumentsForFirstParameter(
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0),
new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1),
new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0))
.argumentsForNextParameter(false, true) // urgent
.argumentsForNextParameter(false, true); // explicitIdentifiers
}
@CartesianTest
@CartesianTest.MethodFactory("testMultiRecipientMessagePni")
void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception {
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages());
}
private static ArgumentSets testMultiRecipientMessagePni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
submap(
targets,
new PniServiceIdentifier(SINGLE_DEVICE_PNI),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new PniServiceIdentifier(NONEXISTENT_UUID));
final boolean auth = true;
final boolean unauth = false;
final boolean story = true;
final boolean notStory = false;
return ArgumentSets
.argumentsForFirstParameter(
new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0),
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0),
new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 3),
new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 200, 1),
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 200, 2),
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 200, 3),
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0))
.argumentsForNextParameter(false, true); // urgent
}
@ParameterizedTest
@@ -1148,7 +1243,7 @@ class MessageControllerTest {
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources
@@ -1346,12 +1441,12 @@ class MessageControllerTest {
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2)
throws NotPushRegisteredException, InterruptedException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, regId2, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
@@ -1384,8 +1479,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
@@ -1393,14 +1488,6 @@ class MessageControllerTest {
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
}
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean());
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.uuids404().isEmpty());
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
@@ -1433,64 +1520,4 @@ class MessageControllerTest {
return builder.build();
}
private static Recipient genRecipient(Random rng) {
UUID u1 = UUID.randomUUID(); // non-null
byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
byte[] perKeyBytes = new byte[48]; // size=48, non-null
rng.nextBytes(perKeyBytes);
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
}
private static void roundTripVarint(byte expected, byte[] bytes) throws Exception {
ByteBuffer bb = ByteBuffer.wrap(bytes);
writePayloadDeviceId(bb, expected);
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
long got = MultiRecipientMessageProvider.readVarint(stream);
assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes)));
}
@Test
void testVarintPayload() throws Exception {
Random rng = new Random();
byte[] bytes = new byte[12];
// some static test cases
for (byte i = 1; i <= 10; i++) {
roundTripVarint(i, bytes);
}
roundTripVarint(Byte.MAX_VALUE, bytes);
for (int i = 0; i < 1000; i++) {
// we need to ensure positive device IDs
byte start = (byte) rng.nextInt(128);
if (start == 0L) {
start = 1;
}
// run the test for this case
roundTripVarint(start, bytes);
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception {
Random rng = new java.util.Random();
List<Recipient> expected = new LinkedList<>();
for(int i = 0; i < 100; i++) {
expected.add(genRecipient(rng));
}
byte[] buffer = new byte[100 + expected.size() * 100];
InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers);
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
// the provider ignores the headers, java reflection, etc. so we don't use those here.
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
List<Recipient> got = Arrays.asList(res.recipients());
assertEquals(expected, got);
}
}

View File

@@ -10,10 +10,7 @@ import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted;
import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked;
import static org.mockito.internal.invocation.InvocationMarker.markVerified;
import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified;
import static org.mockito.internal.invocation.InvocationsFinder.findInvocations;
import java.time.Duration;
import java.util.List;
@@ -169,10 +166,17 @@ public final class MockUtils {
* this method
*/
public static VerificationMode exactly() {
return exactly(1);
}
/**
* a combination of {@link #exactly()} and {@link org.mockito.Mockito#times(int)}, verifies that
* there are exactly N invocations of this method, and all of them match the given specification
*/
public static VerificationMode exactly(int wantedCount) {
return data -> {
MatchableInvocation target = data.getTarget();
final List<Invocation> allInvocations = data.getAllInvocations();
List<Invocation> chunk = findInvocations(allInvocations, target);
List<Invocation> otherInvocations = allInvocations.stream()
.filter(target::hasSameMethod)
.filter(Predicate.not(target::matches))
@@ -182,10 +186,7 @@ public final class MockUtils {
Invocation unverified = findFirstUnverified(otherInvocations);
throw noMoreInteractionsWanted(unverified, (List) allInvocations);
}
if (chunk.isEmpty()) {
throw wantedButNotInvoked(target);
}
markVerified(chunk.get(0), target);
Mockito.times(wantedCount).verify(data);
};
}