mirror of
https://github.com/signalapp/Signal-Server
synced 2026-04-21 02:08:03 +01:00
@@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.controllers;
|
||||
import static org.hamcrest.CoreMatchers.equalTo;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.CoreMatchers.not;
|
||||
import static org.hamcrest.collection.IsEmptyCollection.empty;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
@@ -30,7 +29,6 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
|
||||
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
|
||||
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
@@ -44,23 +42,21 @@ import java.io.ByteArrayInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import javax.ws.rs.client.Entity;
|
||||
import javax.ws.rs.client.Invocation;
|
||||
@@ -77,11 +73,8 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.ArgumentsSources;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
@@ -99,6 +92,8 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
|
||||
@@ -144,7 +139,6 @@ class MessageControllerTest {
|
||||
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
|
||||
private static final byte SINGLE_DEVICE_ID1 = 1;
|
||||
private static final int SINGLE_DEVICE_REG_ID1 = 111;
|
||||
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
|
||||
|
||||
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
|
||||
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
|
||||
@@ -155,11 +149,6 @@ class MessageControllerTest {
|
||||
private static final int MULTI_DEVICE_REG_ID1 = 222;
|
||||
private static final int MULTI_DEVICE_REG_ID2 = 333;
|
||||
private static final int MULTI_DEVICE_REG_ID3 = 444;
|
||||
private static final int MULTI_DEVICE_PNI_REG_ID1 = 2222;
|
||||
private static final int MULTI_DEVICE_PNI_REG_ID2 = 3333;
|
||||
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
|
||||
|
||||
private static final UUID NONEXISTENT_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
|
||||
|
||||
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
|
||||
|
||||
@@ -203,13 +192,13 @@ class MessageControllerTest {
|
||||
|
||||
|
||||
final List<Device> singleDeviceList = List.of(
|
||||
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
|
||||
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis())
|
||||
);
|
||||
|
||||
final List<Device> multiDeviceList = List.of(
|
||||
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
|
||||
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()),
|
||||
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
|
||||
);
|
||||
|
||||
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
|
||||
@@ -222,8 +211,6 @@ class MessageControllerTest {
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
|
||||
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
|
||||
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
|
||||
|
||||
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
|
||||
mock(DynamicInboundMessageByteLimitConfiguration.class);
|
||||
@@ -935,21 +922,25 @@ class MessageControllerTest {
|
||||
);
|
||||
}
|
||||
|
||||
private record Recipient(ServiceIdentifier uuid,
|
||||
byte deviceId,
|
||||
int registrationId,
|
||||
byte[] perRecipientKeyMaterial) {
|
||||
private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) {
|
||||
long x = deviceId;
|
||||
// write the device-id in the 7-bit varint format we use, least significant bytes first.
|
||||
do {
|
||||
long b = x & 0x7f;
|
||||
x = x >>> 7;
|
||||
if (x != 0) b |= 0x80;
|
||||
bb.put((byte)b);
|
||||
} while (x != 0);
|
||||
}
|
||||
|
||||
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r,
|
||||
final boolean useExplicitIdentifier) {
|
||||
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) {
|
||||
if (useExplicitIdentifier) {
|
||||
bb.put(r.uuid().toFixedWidthByteArray());
|
||||
} else {
|
||||
bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
|
||||
}
|
||||
|
||||
bb.put(r.deviceId()); // device id (1 byte)
|
||||
writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes)
|
||||
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
|
||||
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
|
||||
}
|
||||
@@ -962,8 +953,8 @@ class MessageControllerTest {
|
||||
// first write the header
|
||||
bb.put(explicitIdentifiers
|
||||
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
|
||||
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
|
||||
bb.put((byte)recipients.size()); // count varint
|
||||
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
|
||||
bb.put((byte)recipients.size()); // count varint
|
||||
|
||||
Iterator<Recipient> it = recipients.iterator();
|
||||
while (it.hasNext()) {
|
||||
@@ -977,24 +968,23 @@ class MessageControllerTest {
|
||||
return new ByteArrayInputStream(buffer, 0, bb.position());
|
||||
}
|
||||
|
||||
// see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations
|
||||
private void testMultiRecipientMessage(
|
||||
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
|
||||
boolean authorize,
|
||||
boolean isStory,
|
||||
boolean urgent,
|
||||
boolean explicitIdentifier,
|
||||
int expectedStatus,
|
||||
int expectedMessagesSent) throws Exception {
|
||||
final List<Recipient> recipients = new ArrayList<>();
|
||||
destinations.forEach(
|
||||
(serviceIdentifier, deviceToRegistrationId) ->
|
||||
deviceToRegistrationId.forEach(
|
||||
(deviceId, registrationId) ->
|
||||
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception {
|
||||
|
||||
final List<Recipient> recipients;
|
||||
if (recipientUUID == MULTI_DEVICE_UUID) {
|
||||
recipients = List.of(
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
|
||||
);
|
||||
} else {
|
||||
recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
|
||||
}
|
||||
|
||||
// initialize our binary payload and create an input stream
|
||||
byte[] buffer = new byte[2048];
|
||||
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
|
||||
InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier);
|
||||
|
||||
// set up the entity to use in our PUT request
|
||||
@@ -1013,160 +1003,124 @@ class MessageControllerTest {
|
||||
|
||||
// add access header if needed
|
||||
if (authorize) {
|
||||
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
|
||||
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
|
||||
String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES);
|
||||
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
|
||||
}
|
||||
|
||||
// make the PUT request
|
||||
Response response = bldr.put(entity);
|
||||
|
||||
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
|
||||
verify(messageSender,
|
||||
exactly(expectedMessagesSent))
|
||||
.sendMessage(
|
||||
any(),
|
||||
any(),
|
||||
argThat(env -> env.getUrgent() == urgent && !env.hasSourceUuid() && !env.hasSourceDevice()),
|
||||
anyBoolean());
|
||||
if (expectedStatus == 200) {
|
||||
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
|
||||
assertThat(smrmr.uuids404(), is(empty()));
|
||||
if (authorize) {
|
||||
ArgumentCaptor<Envelope> envelopeArgumentCaptor = ArgumentCaptor.forClass(Envelope.class);
|
||||
verify(messageSender, atLeastOnce()).sendMessage(any(), any(), envelopeArgumentCaptor.capture(), anyBoolean());
|
||||
assertEquals(urgent, envelopeArgumentCaptor.getValue().getUrgent());
|
||||
}
|
||||
|
||||
// We have a 2x2x2 grid of possible situations based on:
|
||||
// - recipient enabled stories?
|
||||
// - sender is authorized?
|
||||
// - message is a story?
|
||||
//
|
||||
// (urgent is not included in the grid because it has no effect
|
||||
// on any of the other settings.)
|
||||
|
||||
if (recipientUUID == MULTI_DEVICE_UUID) {
|
||||
// This is the case where the recipient has enabled stories.
|
||||
if(isStory) {
|
||||
// We are sending a story, so we ignore access checks and expect this
|
||||
// to go out to both the recipient's devices.
|
||||
checkGoodMultiRecipientResponse(response, 2);
|
||||
} else {
|
||||
// We are not sending a story, so we need to do access checks.
|
||||
if (authorize) {
|
||||
// When authorized we send a message to the recipient's devices.
|
||||
checkGoodMultiRecipientResponse(response, 2);
|
||||
} else {
|
||||
// When forbidden, we return a 401 error.
|
||||
checkBadMultiRecipientResponse(response, 401);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// This is the case where the recipient has not enabled stories.
|
||||
if (isStory) {
|
||||
// We are sending a story, so we ignore access checks.
|
||||
// this recipient has one device.
|
||||
checkGoodMultiRecipientResponse(response, 1);
|
||||
} else {
|
||||
// We are not sending a story so check access.
|
||||
if (authorize) {
|
||||
// If allowed, send a message to the recipient's one device.
|
||||
checkGoodMultiRecipientResponse(response, 1);
|
||||
} else {
|
||||
// If forbidden, return a 401 error.
|
||||
checkBadMultiRecipientResponse(response, 401);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@SafeVarargs
|
||||
private static <K, V> Map<K, V> submap(Map<K, V> map, K... keys) {
|
||||
return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get));
|
||||
// Arguments here are: recipient-UUID, is-authorized?, is-story?
|
||||
private static Stream<Arguments> testMultiRecipientMessage() {
|
||||
return Stream.of(
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, true, false, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, false, false, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, true, false, true),
|
||||
Arguments.of(MULTI_DEVICE_UUID, true, false, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true),
|
||||
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true)
|
||||
);
|
||||
}
|
||||
|
||||
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() {
|
||||
return
|
||||
Map.of(
|
||||
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
|
||||
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
|
||||
new AciServiceIdentifier(MULTI_DEVICE_UUID),
|
||||
Map.of(
|
||||
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
|
||||
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
|
||||
new PniServiceIdentifier(MULTI_DEVICE_PNI),
|
||||
Map.of(
|
||||
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
|
||||
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
|
||||
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
|
||||
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
|
||||
);
|
||||
}
|
||||
@Test
|
||||
void testMultiRecipientMessageToAccountsSomeOfWhichDoNotExist() throws Exception {
|
||||
UUID badUUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
|
||||
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(badUUID))).thenReturn(Optional.empty());
|
||||
|
||||
private record MultiRecipientMessageTestCase(
|
||||
Map<ServiceIdentifier, Map<Byte, Integer>> destinations,
|
||||
boolean authenticated,
|
||||
boolean story,
|
||||
int expectedStatus,
|
||||
int expectedSentMessages) {
|
||||
}
|
||||
final List<Recipient> recipients = List.of(
|
||||
new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1,
|
||||
new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(badUUID), (byte) 1, 1, new byte[48]));
|
||||
|
||||
@CartesianTest
|
||||
@CartesianTest.MethodFactory("testMultiRecipientMessageNoPni")
|
||||
void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception {
|
||||
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages());
|
||||
}
|
||||
Response response = resources
|
||||
.getJerseyTest()
|
||||
.target("/v1/messages/multi_recipient")
|
||||
.queryParam("online", true)
|
||||
.queryParam("ts", 1700000000000L)
|
||||
.queryParam("story", true)
|
||||
.queryParam("urgent", false)
|
||||
.request()
|
||||
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
|
||||
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
|
||||
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], true),
|
||||
MultiRecipientMessageProvider.MEDIA_TYPE));
|
||||
|
||||
private static ArgumentSets testMultiRecipientMessageNoPni() {
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
|
||||
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
|
||||
submap(
|
||||
targets,
|
||||
new AciServiceIdentifier(SINGLE_DEVICE_UUID),
|
||||
new AciServiceIdentifier(MULTI_DEVICE_UUID),
|
||||
new AciServiceIdentifier(NONEXISTENT_UUID));
|
||||
|
||||
final boolean auth = true;
|
||||
final boolean unauth = false;
|
||||
final boolean story = true;
|
||||
final boolean notStory = false;
|
||||
|
||||
return ArgumentSets
|
||||
.argumentsForFirstParameter(
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1),
|
||||
new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0))
|
||||
.argumentsForNextParameter(false, true) // urgent
|
||||
.argumentsForNextParameter(false, true); // explicitIdentifiers
|
||||
}
|
||||
|
||||
@CartesianTest
|
||||
@CartesianTest.MethodFactory("testMultiRecipientMessagePni")
|
||||
void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception {
|
||||
testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages());
|
||||
}
|
||||
|
||||
private static ArgumentSets testMultiRecipientMessagePni() {
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
|
||||
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
|
||||
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI));
|
||||
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
|
||||
submap(
|
||||
targets,
|
||||
new PniServiceIdentifier(SINGLE_DEVICE_PNI),
|
||||
new AciServiceIdentifier(MULTI_DEVICE_UUID),
|
||||
new PniServiceIdentifier(NONEXISTENT_UUID));
|
||||
|
||||
final boolean auth = true;
|
||||
final boolean unauth = false;
|
||||
final boolean story = true;
|
||||
final boolean notStory = false;
|
||||
|
||||
return ArgumentSets
|
||||
.argumentsForFirstParameter(
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 3),
|
||||
|
||||
new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 200, 1),
|
||||
new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2),
|
||||
new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 200, 2),
|
||||
new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 200, 3),
|
||||
new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0))
|
||||
.argumentsForNextParameter(false, true); // urgent
|
||||
checkGoodMultiRecipientResponse(response, 1);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@@ -1174,7 +1128,7 @@ class MessageControllerTest {
|
||||
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
|
||||
final List<Recipient> recipients = List.of(
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
|
||||
|
||||
Response response = resources
|
||||
@@ -1372,12 +1326,12 @@ class MessageControllerTest {
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier, final int regId1, final int regId2)
|
||||
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
|
||||
throws NotPushRegisteredException, InterruptedException {
|
||||
|
||||
final List<Recipient> recipients = List.of(
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, regId1, new byte[48]),
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, regId2, new byte[48]));
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
|
||||
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
|
||||
|
||||
// initialize our binary payload and create an input stream
|
||||
byte[] buffer = new byte[2048];
|
||||
@@ -1410,8 +1364,8 @@ class MessageControllerTest {
|
||||
|
||||
private static Stream<Arguments> sendMultiRecipientMessage404() {
|
||||
return Stream.of(
|
||||
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
|
||||
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
|
||||
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
|
||||
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
|
||||
}
|
||||
|
||||
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
|
||||
@@ -1419,6 +1373,14 @@ class MessageControllerTest {
|
||||
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
|
||||
}
|
||||
|
||||
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
|
||||
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
|
||||
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
|
||||
verify(messageSender, times(expectedCount)).sendMessage(any(), any(), any(), anyBoolean());
|
||||
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
|
||||
assert (smrmr.uuids404().isEmpty());
|
||||
}
|
||||
|
||||
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
|
||||
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
|
||||
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
|
||||
@@ -1451,4 +1413,64 @@ class MessageControllerTest {
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
private static Recipient genRecipient(Random rng) {
|
||||
UUID u1 = UUID.randomUUID(); // non-null
|
||||
byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127
|
||||
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
|
||||
byte[] perKeyBytes = new byte[48]; // size=48, non-null
|
||||
rng.nextBytes(perKeyBytes);
|
||||
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
|
||||
}
|
||||
|
||||
private static void roundTripVarint(byte expected, byte[] bytes) throws Exception {
|
||||
ByteBuffer bb = ByteBuffer.wrap(bytes);
|
||||
writePayloadDeviceId(bb, expected);
|
||||
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
|
||||
long got = MultiRecipientMessageProvider.readVarint(stream);
|
||||
assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes)));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testVarintPayload() throws Exception {
|
||||
Random rng = new Random();
|
||||
byte[] bytes = new byte[12];
|
||||
|
||||
// some static test cases
|
||||
for (byte i = 1; i <= 10; i++) {
|
||||
roundTripVarint(i, bytes);
|
||||
}
|
||||
roundTripVarint(Byte.MAX_VALUE, bytes);
|
||||
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
// we need to ensure positive device IDs
|
||||
byte start = (byte) rng.nextInt(128);
|
||||
if (start == 0L) {
|
||||
start = 1;
|
||||
}
|
||||
|
||||
// run the test for this case
|
||||
roundTripVarint(start, bytes);
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(booleans = {true, false})
|
||||
void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception {
|
||||
Random rng = new java.util.Random();
|
||||
List<Recipient> expected = new LinkedList<>();
|
||||
for(int i = 0; i < 100; i++) {
|
||||
expected.add(genRecipient(rng));
|
||||
}
|
||||
|
||||
byte[] buffer = new byte[100 + expected.size() * 100];
|
||||
InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers);
|
||||
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
|
||||
// the provider ignores the headers, java reflection, etc. so we don't use those here.
|
||||
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
|
||||
List<Recipient> got = Arrays.asList(res.recipients());
|
||||
|
||||
assertEquals(expected, got);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -10,8 +10,6 @@ import static org.mockito.Mockito.doNothing;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted;
|
||||
import static org.mockito.internal.exceptions.Reporter.tooFewActualInvocations;
|
||||
import static org.mockito.internal.exceptions.Reporter.tooManyActualInvocations;
|
||||
import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked;
|
||||
import static org.mockito.internal.invocation.InvocationMarker.markVerified;
|
||||
import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified;
|
||||
@@ -28,7 +26,6 @@ import org.mockito.Mockito;
|
||||
import org.mockito.invocation.Invocation;
|
||||
import org.mockito.invocation.MatchableInvocation;
|
||||
import org.mockito.verification.VerificationMode;
|
||||
import org.mockito.internal.verification.Times;
|
||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||
@@ -174,17 +171,10 @@ public final class MockUtils {
|
||||
* this method
|
||||
*/
|
||||
public static VerificationMode exactly() {
|
||||
return exactly(1);
|
||||
}
|
||||
|
||||
/**
|
||||
* a combination of {@link #exactly()} and {@link org.mockito.Mockito#times(int)}, verifies that
|
||||
* there are exactly N invocations of this method, and all of them match the given specification
|
||||
*/
|
||||
public static VerificationMode exactly(int wantedCount) {
|
||||
return data -> {
|
||||
MatchableInvocation target = data.getTarget();
|
||||
final List<Invocation> allInvocations = data.getAllInvocations();
|
||||
List<Invocation> chunk = findInvocations(allInvocations, target);
|
||||
List<Invocation> otherInvocations = allInvocations.stream()
|
||||
.filter(target::hasSameMethod)
|
||||
.filter(Predicate.not(target::matches))
|
||||
@@ -194,7 +184,10 @@ public final class MockUtils {
|
||||
Invocation unverified = findFirstUnverified(otherInvocations);
|
||||
throw noMoreInteractionsWanted(unverified, (List) allInvocations);
|
||||
}
|
||||
Mockito.times(wantedCount).verify(data);
|
||||
if (chunk.isEmpty()) {
|
||||
throw wantedButNotInvoked(target);
|
||||
}
|
||||
markVerified(chunk.get(0), target);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user