Add routing for stories.

This commit is contained in:
erik-signal
2022-10-05 10:32:10 -04:00
committed by Erik Osheim
parent c2ab72c77e
commit 966c3a8f47
20 changed files with 425 additions and 86 deletions

View File

@@ -32,15 +32,21 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
@@ -62,11 +68,13 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -80,6 +88,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.Stories;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest {
@@ -87,10 +96,20 @@ class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final int SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
private static final int MULTI_DEVICE_ID1 = 1;
private static final int MULTI_DEVICE_ID2 = 2;
private static final int MULTI_DEVICE_ID3 = 3;
private static final int MULTI_DEVICE_REG_ID1 = 222;
private static final int MULTI_DEVICE_REG_ID2 = 333;
private static final int MULTI_DEVICE_REG_ID3 = 444;
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
@@ -116,6 +135,7 @@ class MessageControllerTest {
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(MultiRecipientMessageProvider.class)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager,
@@ -125,18 +145,18 @@ class MessageControllerTest {
@BeforeEach
void setup() {
final List<Device> singleDeviceList = List.of(
generateTestDevice(1, 111, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis())
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis())
);
final List<Device> multiDeviceList = List.of(
generateTestDevice(1, 222, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(2, 333, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(3, 444, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
);
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes());
Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, "1234".getBytes());
internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes());
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, UNIDENTIFIED_ACCESS_BYTES);
internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount));
@@ -171,7 +191,8 @@ class MessageControllerTest {
rateLimiters,
rateLimiter,
pushNotificationManager,
reportMessageManager
reportMessageManager,
multiRecipientMessageExecutor
);
}
@@ -270,7 +291,7 @@ class MessageControllerTest {
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
@@ -412,8 +433,9 @@ class MessageControllerTest {
verifyNoMoreInteractions(messageSender);
}
@Test
void testGetMessages() {
@ParameterizedTest
@MethodSource
void testGetMessages(boolean receiveStories) {
final long timestampOne = 313377;
final long timestampTwo = 313388;
@@ -424,19 +446,15 @@ class MessageControllerTest {
final UUID updatedPniOne = UUID.randomUUID();
List<Envelope> messages = List.of(
List<Envelope> envelopes = List.of(
generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2,
AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0),
AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0, false),
generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2,
AuthHelper.VALID_UUID, null, null, 0)
AuthHelper.VALID_UUID, null, null, 0, true)
);
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages.stream()
.map(OutgoingMessageEntity::fromEnvelope)
.toList(), false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(messages, false));
.thenReturn(new Pair<>(envelopes, false));
final String userAgent = "Test-UA";
@@ -444,27 +462,39 @@ class MessageControllerTest {
resources.getJerseyTest().target("/v1/messages/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(Stories.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false")
.header("USer-Agent", userAgent)
.accept(MediaType.APPLICATION_JSON_TYPE)
.get(OutgoingMessageEntityList.class);
assertEquals(response.messages().size(), 2);
List<OutgoingMessageEntity> messages = response.messages();
int expectedSize = receiveStories ? 2 : 1;
assertEquals(expectedSize, messages.size());
assertEquals(response.messages().get(0).timestamp(), timestampOne);
assertEquals(response.messages().get(1).timestamp(), timestampTwo);
OutgoingMessageEntity first = messages.get(0);
assertEquals(first.timestamp(), timestampOne);
assertEquals(first.guid(), messageGuidOne);
assertEquals(first.sourceUuid(), sourceUuid);
assertEquals(updatedPniOne, first.updatedPni());
assertEquals(response.messages().get(0).guid(), messageGuidOne);
assertEquals(response.messages().get(1).guid(), messageGuidTwo);
assertEquals(response.messages().get(0).sourceUuid(), sourceUuid);
assertEquals(response.messages().get(1).sourceUuid(), sourceUuid);
assertEquals(updatedPniOne, response.messages().get(0).updatedPni());
assertNull(response.messages().get(1).updatedPni());
if (receiveStories) {
OutgoingMessageEntity second = messages.get(1);
assertEquals(second.timestamp(), timestampTwo);
assertEquals(second.guid(), messageGuidTwo);
assertEquals(second.sourceUuid(), sourceUuid);
assertNull(second.updatedPni());
}
verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent);
}
private static Stream<Arguments> testGetMessages() {
return Stream.of(
Arguments.of(true),
Arguments.of(false)
);
}
@Test
void testGetMessagesBadAuth() {
final long timestampOne = 313377;
@@ -644,9 +674,9 @@ class MessageControllerTest {
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes()))
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true,
List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, false,
System.currentTimeMillis()),
MediaType.APPLICATION_JSON_TYPE));
@@ -686,14 +716,166 @@ class MessageControllerTest {
);
}
private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception {
bb.putLong(msb); // uuid (first 8 bytes)
bb.putLong(lsb); // uuid (last 8 bytes)
int x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
bb.put((byte)(x & 0x7f));
x = x >>> 7;
} while (x != 0);
bb.putShort((short) regId); // registration id short
bb.put(new byte[48]); // key material (48 bytes)
}
private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
// determine how many recipient/device pairs we will be writing
int count;
if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; }
else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; }
else { throw new Exception("unknown UUID: " + recipientUUID); }
// first write the header header
bb.put(MultiRecipientMessageProvider.VERSION); // version byte
bb.put((byte)count); // count varint, # of active devices for this user
long msb = recipientUUID.getMostSignificantBits();
long lsb = recipientUUID.getLeastSignificantBits();
// write the recipient data for each recipient/device pair
if (recipientUUID == MULTI_DEVICE_UUID) {
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1);
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2);
} else {
writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1);
}
// now write the actual message body (empty for now)
bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here)
// return the input stream
return new ByteArrayInputStream(buffer, 0, bb.position());
}
@ParameterizedTest
@MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory) throws Exception {
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipientUUID, buffer);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
Invocation.Builder bldr = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", isStory)
.request()
.header("User-Agent", "FIXME");
// add access header if needed
if (authorize) {
String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
}
// make the PUT request
Response response = bldr.put(entity);
// We have a 2x2x2 grid of possible situations based on:
// - recipient enabled stories?
// - sender is authorized?
// - message is a story?
if (recipientUUID == MULTI_DEVICE_UUID) {
// This is the case where the recipient has enabled stories.
if(isStory) {
// We are sending a story, so we ignore access checks and expect this
// to go out to both the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// We are not sending a story, so we need to do access checks.
if (authorize) {
// When authorized we send a message to the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// When forbidden, we return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
} else {
// This is the case where the recipient has not enabled stories.
if (isStory) {
// We are sending a story, so we ignore access checks.
// this recipient has one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// We are not sending a story so check access.
if (authorize) {
// If allowed, send a message to the recipient's one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// If forbidden, return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
}
}
// Arguments here are: recipient-UUID, is-authorized?, is-story?
private static Stream<Arguments> testMultiRecipientMessage() {
return Stream.of(
Arguments.of(MULTI_DEVICE_UUID, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false)
);
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
verify(multiRecipientMessageExecutor, never()).invokeAll(any());
}
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture());
assert (captor.getValue().size() == expectedCount);
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.getUUIDs404().isEmpty());
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type))
.setTimestamp(timestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString())
.setStory(story)
.setServerGuid(guid.toString());
if (sourceUuid != null) {

View File

@@ -36,7 +36,8 @@ class OutgoingMessageEntityTest {
updatedPni,
messageContent,
serverTimestamp,
true);
true,
false);
assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope()));
}

View File

@@ -70,7 +70,7 @@ class MessageMetricsTest {
}
private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) {
return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true);
return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true, false);
}
@Test