Introduce "service identifiers"

This commit is contained in:
Jon Chambers
2023-07-21 09:34:10 -04:00
committed by GitHub
parent 4a6c7152cf
commit abb32bd919
39 changed files with 1304 additions and 588 deletions

View File

@@ -70,6 +70,8 @@ import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimitByIpFilter;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@@ -869,10 +871,9 @@ class AccountControllerTest {
final UUID accountIdentifier = UUID.randomUUID();
final UUID phoneNumberIdentifier = UUID.randomUUID();
when(accountsManager.getByAccountIdentifier(any())).thenReturn(Optional.empty());
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
when(accountsManager.getByPhoneNumberIdentifier(any())).thenReturn(Optional.empty());
when(accountsManager.getByPhoneNumberIdentifier(phoneNumberIdentifier)).thenReturn(Optional.of(account));
when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(accountIdentifier))).thenReturn(Optional.of(account));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(phoneNumberIdentifier))).thenReturn(Optional.of(account));
when(rateLimiters.getCheckAccountExistenceLimiter()).thenReturn(mock(RateLimiter.class));
@@ -884,7 +885,7 @@ class AccountControllerTest {
.getStatus()).isEqualTo(200);
assertThat(resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", phoneNumberIdentifier))
.target(String.format("/v1/accounts/account/PNI:%s", phoneNumberIdentifier))
.request()
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1")
.head()
@@ -954,7 +955,7 @@ class AccountControllerTest {
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1")
.get();
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.readEntity(AccountIdentifierResponse.class).uuid()).isEqualTo(uuid);
assertThat(response.readEntity(AccountIdentifierResponse.class).uuid().uuid()).isEqualTo(uuid);
}
@Test

View File

@@ -57,6 +57,8 @@ import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
@@ -77,7 +79,6 @@ class KeysControllerTest {
private static final UUID EXISTS_UUID = UUID.randomUUID();
private static final UUID EXISTS_PNI = UUID.randomUUID();
private static final String NOT_EXISTS_NUMBER = "+14152222220";
private static final UUID NOT_EXISTS_UUID = UUID.randomUUID();
private static final int SAMPLE_REGISTRATION_ID = 999;
@@ -212,12 +213,10 @@ class KeysControllerTest {
when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER);
when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes()));
when(accounts.getByE164(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount));
when(accounts.getByAccountIdentifier(EXISTS_UUID)).thenReturn(Optional.of(existsAccount));
when(accounts.getByPhoneNumberIdentifier(EXISTS_PNI)).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accounts.getByE164(NOT_EXISTS_NUMBER)).thenReturn(Optional.empty());
when(accounts.getByAccountIdentifier(NOT_EXISTS_UUID)).thenReturn(Optional.empty());
when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
@@ -384,7 +383,7 @@ class KeysControllerTest {
@Test
void validSingleRequestByPhoneNumberIdentifierTestV2() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI))
.target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
@@ -404,7 +403,7 @@ class KeysControllerTest {
@Test
void validSingleRequestPqByPhoneNumberIdentifierTestV2() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI))
.target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
@@ -428,7 +427,7 @@ class KeysControllerTest {
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI))
.target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
@@ -451,7 +450,7 @@ class KeysControllerTest {
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString());
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_PNI))
.target(String.format("/v2/keys/PNI:%s/*", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();

View File

@@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
@@ -29,6 +30,7 @@ import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@@ -42,11 +44,13 @@ import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
@@ -69,6 +73,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
@@ -78,6 +83,8 @@ import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
@@ -91,11 +98,15 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
@@ -111,6 +122,7 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.websocket.Stories;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
@@ -191,11 +203,11 @@ class MessageControllerTest {
Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID,
UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByAccountIdentifier(eq(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(MULTI_DEVICE_PNI)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
final DynamicDeliveryLatencyConfiguration deliveryLatencyConfiguration = mock(DynamicDeliveryLatencyConfiguration.class);
when(deliveryLatencyConfiguration.instrumentedVersions()).thenReturn(Collections.emptyMap());
@@ -310,7 +322,7 @@ class MessageControllerTest {
void testSingleDeviceCurrentByPni() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_PNI))
.target(String.format("/v1/messages/PNI:%s", SINGLE_DEVICE_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
@@ -471,7 +483,7 @@ class MessageControllerTest {
void testMultiDeviceByPni() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_PNI))
.target(String.format("/v1/messages/PNI:%s", MULTI_DEVICE_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"),
@@ -543,14 +555,14 @@ class MessageControllerTest {
OutgoingMessageEntity first = messages.get(0);
assertEquals(first.timestamp(), timestampOne);
assertEquals(first.guid(), messageGuidOne);
assertEquals(first.sourceUuid(), sourceUuid);
assertEquals(first.sourceUuid().uuid(), sourceUuid);
assertEquals(updatedPniOne, first.updatedPni());
if (receiveStories) {
OutgoingMessageEntity second = messages.get(1);
assertEquals(second.timestamp(), timestampTwo);
assertEquals(second.guid(), messageGuidTwo);
assertEquals(second.sourceUuid(), sourceUuid);
assertEquals(second.sourceUuid().uuid(), sourceUuid);
assertNull(second.updatedPni());
}
@@ -623,8 +635,8 @@ class MessageControllerTest {
.delete();
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(eq(AuthHelper.VALID_UUID), eq(1L),
eq(sourceUuid), eq(timestamp));
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq(1L),
eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp));
response = resources.getJerseyTest()
.target(String.format("/v1/messages/uuid/%s", uuid2))
@@ -920,28 +932,32 @@ class MessageControllerTest {
} while (x != 0);
}
private static void writeMultiPayloadRecipient(ByteBuffer bb, Recipient r) throws Exception {
long msb = r.getUuid().getMostSignificantBits();
long lsb = r.getUuid().getLeastSignificantBits();
bb.putLong(msb); // uuid (first 8 bytes)
bb.putLong(lsb); // uuid (last 8 bytes)
writePayloadDeviceId(bb, r.getDeviceId()); // device id (1-9 bytes)
bb.putShort((short) r.getRegistrationId()); // registration id (2 bytes)
bb.put(r.getPerRecipientKeyMaterial()); // key material (48 bytes)
private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) {
bb.put(r.uuid().toFixedWidthByteArray());
} else {
bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
}
writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes)
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception {
private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer, final boolean explicitIdentifiers) {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
// first write the header
bb.put(MultiRecipientMessageProvider.VERSION); // version byte
bb.put(explicitIdentifiers
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
bb.put((byte)recipients.size()); // count varint
Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) {
writeMultiPayloadRecipient(bb, it.next());
writeMultiPayloadRecipient(bb, it.next(), explicitIdentifiers);
}
// now write the actual message body (empty for now)
@@ -953,22 +969,22 @@ class MessageControllerTest {
@ParameterizedTest
@MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent) throws Exception {
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception {
final List<Recipient> recipients;
if (recipientUUID == MULTI_DEVICE_UUID) {
recipients = List.of(
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
);
} else {
recipients = List.of(new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
}
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
//InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
@@ -1058,31 +1074,48 @@ class MessageControllerTest {
// Arguments here are: recipient-UUID, is-authorized?, is-story?
private static Stream<Arguments> testMultiRecipientMessage() {
return Stream.of(
Arguments.of(MULTI_DEVICE_UUID, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false)
Arguments.of(MULTI_DEVICE_UUID, false, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true)
);
}
@Test
void testMultiRecipientRedisBombProtection() throws Exception {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources
.getJerseyTest()
@@ -1094,7 +1127,7 @@ class MessageControllerTest {
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048]), MultiRecipientMessageProvider.MEDIA_TYPE));
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE));
checkBadMultiRecipientResponse(response, 422);
}
@@ -1118,22 +1151,22 @@ class MessageControllerTest {
@ParameterizedTest
@MethodSource
void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known) throws Exception {
void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known, boolean useExplicitIdentifier) {
final Recipient r1;
if (known) {
r1 = new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else {
r1 = new Recipient(UUID.randomUUID(), 999, 999, new byte[48]);
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), 999, 999, new byte[48]);
}
Recipient r2 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
List<Recipient> recipients = List.of(r1, r2, r3);
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, useExplicitIdentifier);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
@@ -1167,10 +1200,170 @@ class MessageControllerTest {
private static Stream<Arguments> testSendMultiRecipientMessageToUnknownAccounts() {
return Stream.of(
Arguments.of(true, true),
Arguments.of(true, false),
Arguments.of(false, true),
Arguments.of(false, false));
Arguments.of(true, true, false),
Arguments.of(true, false, false),
Arguments.of(false, true, false),
Arguments.of(false, false, false),
Arguments.of(true, true, true),
Arguments.of(true, false, true),
Arguments.of(false, true, true),
Arguments.of(false, false, true)
);
}
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessageMismatchedDevices(final ServiceIdentifier serviceIdentifier)
throws JsonProcessingException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
assertEquals(409, response.getStatus());
final List<AccountMismatchedDevices> mismatchedDevices =
SystemMapper.jsonMapper().readValue(response.readEntity(String.class),
SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountMismatchedDevices.class));
assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(Collections.emptyList(), List.of((long) MULTI_DEVICE_ID3)))),
mismatchedDevices);
}
private static Stream<Arguments> sendMultiRecipientMessageMismatchedDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
}
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessageStaleDevices(final ServiceIdentifier serviceIdentifier) throws JsonProcessingException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1 + 1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2 + 1, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
assertEquals(410, response.getStatus());
final List<AccountStaleDevices> staleDevices =
SystemMapper.jsonMapper().readValue(response.readEntity(String.class),
SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountStaleDevices.class));
assertEquals(1, staleDevices.size());
assertEquals(serviceIdentifier, staleDevices.get(0).uuid());
assertEquals(Set.of((long) MULTI_DEVICE_ID1, (long) MULTI_DEVICE_ID2), new HashSet<>(staleDevices.get(0).devices().staleDevices()));
}
private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
}
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
throws NotPushRegisteredException, InterruptedException {
when(multiRecipientMessageExecutor.invokeAll(any()))
.thenAnswer(answer -> {
final List<Callable> tasks = answer.getArgument(0, List.class);
tasks.forEach(c -> {
try {
c.call();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return null;
});
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", true)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
doThrow(NotPushRegisteredException.class)
.when(messageSender).sendMessage(any(), any(), any(), anyBoolean());
// make the PUT request
final SendMultiRecipientMessageResponse response = invocationBuilder.put(entity, SendMultiRecipientMessageResponse.class);
assertEquals(List.of(serviceIdentifier), response.uuids404());
}
private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
@@ -1185,7 +1378,7 @@ class MessageControllerTest {
verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture());
assert (captor.getValue().size() == expectedCount);
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.getUUIDs404().isEmpty());
assert (smrmr.uuids404().isEmpty());
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
@@ -1226,7 +1419,7 @@ class MessageControllerTest {
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
byte[] perKeyBytes = new byte[48]; // size=48, non-null
rng.nextBytes(perKeyBytes);
return new Recipient(u1, d1, dr1, perKeyBytes);
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
}
private static void roundTripVarint(long expected, byte [] bytes) throws Exception {
@@ -1258,8 +1451,9 @@ class MessageControllerTest {
}
}
@Test
void testMultiPayloadRoundtrip() throws Exception {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception {
Random rng = new java.util.Random();
List<Recipient> expected = new LinkedList<>();
for(int i = 0; i < 100; i++) {
@@ -1267,11 +1461,11 @@ class MessageControllerTest {
}
byte[] buffer = new byte[100 + expected.size() * 100];
InputStream entityStream = initializeMultiPayload(expected, buffer);
InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers);
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
// the provider ignores the headers, java reflection, etc. so we don't use those here.
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
List<Recipient> got = Arrays.asList(res.getRecipients());
List<Recipient> got = Arrays.asList(res.recipients());
assertEquals(expected, got);
}

View File

@@ -38,7 +38,6 @@ import java.util.Collections;
import java.util.HexFormat;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.Executors;
@@ -90,6 +89,9 @@ import org.whispersystems.textsecuregcm.entities.CreateProfileRequest;
import org.whispersystems.textsecuregcm.entities.ExpiringProfileKeyCredentialProfileResponse;
import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes;
import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
@@ -202,6 +204,7 @@ class ProfileControllerTest {
Account capabilitiesAccount = mock(Account.class);
when(capabilitiesAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(capabilitiesAccount.getIdentityKey()).thenReturn(ACCOUNT_IDENTITY_KEY);
when(capabilitiesAccount.getPhoneNumberIdentityKey()).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY);
when(capabilitiesAccount.isEnabled()).thenReturn(true);
@@ -209,20 +212,23 @@ class ProfileControllerTest {
when(capabilitiesAccount.isAnnouncementGroupSupported()).thenReturn(true);
when(capabilitiesAccount.isChangeNumberSupported()).thenReturn(true);
when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByPhoneNumberIdentifier(AuthHelper.VALID_PNI_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO))).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO))).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByUsernameHash(USERNAME_HASH)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(capabilitiesAccount));
when(profilesManager.get(eq(AuthHelper.VALID_UUID), eq("someversion"))).thenReturn(Optional.empty());
when(profilesManager.get(eq(AuthHelper.VALID_UUID_TWO), eq("validversion"))).thenReturn(Optional.of(new VersionedProfile(
"validversion", "validname", "profiles/validavatar", "emoji", "about", null, "validcommitmnet".getBytes())));
when(accountsManager.getByAccountIdentifier(AuthHelper.INVALID_UUID)).thenReturn(Optional.empty());
clearInvocations(rateLimiter);
clearInvocations(accountsManager);
clearInvocations(usernameRateLimiter);
@@ -308,14 +314,14 @@ class ProfileControllerTest {
@Test
void testProfileGetByPni() throws RateLimitExceededException {
final BaseProfileResponse profile = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
.target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(BaseProfileResponse.class);
assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY);
assertThat(profile.getBadges()).isEmpty();
assertThat(profile.getUuid()).isEqualTo(AuthHelper.VALID_PNI_TWO);
assertThat(profile.getUuid()).isEqualTo(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO));
assertThat(profile.getCapabilities()).isNotNull();
assertThat(profile.isUnrestrictedUnidentifiedAccess()).isFalse();
assertThat(profile.getUnidentifiedAccess()).isNull();
@@ -342,7 +348,7 @@ class ProfileControllerTest {
@Test
void testProfileGetByPniUnidentified() throws RateLimitExceededException {
final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
.target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get();
@@ -836,7 +842,7 @@ class ProfileControllerTest {
assertThat(profile.getAboutEmoji()).isEqualTo("emoji");
assertThat(profile.getAvatar()).isEqualTo("profiles/validavatar");
assertThat(profile.getBaseProfileResponse().getCapabilities().gv1Migration()).isTrue();
assertThat(profile.getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID_TWO);
assertThat(profile.getBaseProfileResponse().getUuid()).isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO));
assertThat(profile.getBaseProfileResponse().getBadges()).hasSize(1).element(0).has(new Condition<>(
badge -> "Test Badge".equals(badge.getName()), "has badge with expected name"));
@@ -927,7 +933,9 @@ class ProfileControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(ExpiringProfileKeyCredentialProfileResponse.class);
assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid())
.isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID));
assertThat(profile.getCredential()).isNull();
verify(zkProfileOperations, never()).issueExpiringProfileKeyCredential(any(), any(), any(), any());
@@ -1092,7 +1100,8 @@ class ProfileControllerTest {
.headers(authHeaders)
.get(ExpiringProfileKeyCredentialProfileResponse.class);
assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid())
.isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID));
assertThat(profile.getCredential()).isEqualTo(credentialResponse);
verify(zkProfileOperations).issueExpiringProfileKeyCredential(credentialRequest, AuthHelper.VALID_UUID, profileKeyCommitment, expiration);
@@ -1154,13 +1163,13 @@ class ProfileControllerTest {
void testBatchIdentityCheck() {
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.json(new BatchIdentityCheckRequest(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null,
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID),
convertKeyToFingerprint(ACCOUNT_IDENTITY_KEY)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO,
new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO),
convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO,
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO),
convertKeyToFingerprint(ACCOUNT_TWO_IDENTITY_KEY)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null,
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID),
convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY))
))))) {
assertThat(response).isNotNull();
@@ -1170,17 +1179,14 @@ class ProfileControllerTest {
assertThat(identityCheckResponse.elements()).isNotNull().isEmpty();
}
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> {
if (AuthHelper.VALID_UUID.equals(element.aci())) {
return Objects.equals(ACCOUNT_IDENTITY_KEY, element.identityKey());
} else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) {
return Objects.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey());
} else if (AuthHelper.VALID_UUID_TWO.equals(element.uuid())) {
return Objects.equals(ACCOUNT_TWO_IDENTITY_KEY, element.identityKey());
} else {
return false;
}
}, "is an expected UUID with the correct identity key");
final Map<ServiceIdentifier, IdentityKey> expectedIdentityKeys = Map.of(
new AciServiceIdentifier(AuthHelper.VALID_UUID), ACCOUNT_IDENTITY_KEY,
new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY,
new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO), ACCOUNT_TWO_IDENTITY_KEY);
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid =
new Condition<>(element -> element.identityKey().equals(expectedIdentityKeys.get(element.uuid())),
"is an expected UUID with the correct identity key");
final IdentityKey validAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final IdentityKey secondValidPniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@@ -1189,13 +1195,13 @@ class ProfileControllerTest {
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.json(new BatchIdentityCheckRequest(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null,
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID),
convertKeyToFingerprint(validAciIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO,
new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO),
convertKeyToFingerprint(secondValidPniIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO,
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO),
convertKeyToFingerprint(secondValidAciIdentityKey)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null,
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID),
convertKeyToFingerprint(invalidAciIdentityKey))
))))) {
assertThat(response).isNotNull();
@@ -1209,13 +1215,13 @@ class ProfileControllerTest {
}
final List<BatchIdentityCheckRequest.Element> largeElementList = new ArrayList<>(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertKeyToFingerprint(validAciIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertKeyToFingerprint(secondValidPniIdentityKey)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertKeyToFingerprint(invalidAciIdentityKey))));
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID), convertKeyToFingerprint(validAciIdentityKey)),
new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), convertKeyToFingerprint(secondValidPniIdentityKey)),
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID), convertKeyToFingerprint(invalidAciIdentityKey))));
for (int i = 0; i < 900; i++) {
largeElementList.add(
new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))));
new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(UUID.randomUUID()), convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))));
}
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
@@ -1233,27 +1239,25 @@ class ProfileControllerTest {
@Test
void testBatchIdentityCheckDeserialization() throws Exception {
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> {
if (AuthHelper.VALID_UUID.equals(element.aci())) {
return ACCOUNT_IDENTITY_KEY.equals(element.identityKey());
} else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) {
return ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY.equals(element.identityKey());
} else {
return false;
}
}, "is an expected UUID with the correct identity key");
final Map<ServiceIdentifier, IdentityKey> expectedIdentityKeys = Map.of(
new AciServiceIdentifier(AuthHelper.VALID_UUID), ACCOUNT_IDENTITY_KEY,
new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY);
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid =
new Condition<>(element -> element.identityKey().equals(expectedIdentityKeys.get(element.uuid())),
"is an expected UUID with the correct identity key");
// null properties are ok to omit
final String json = String.format("""
{
"elements": [
{ "aci": "%s", "fingerprint": "%s" },
{ "uuid": "%s", "fingerprint": "%s" },
{ "aci": "%s", "fingerprint": "%s" }
{ "uuid": "%s", "fingerprint": "%s" },
{ "uuid": "%s", "fingerprint": "%s" }
]
}
""", AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))),
AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))),
"PNI:" + AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))),
AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))));
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
@@ -1277,50 +1281,34 @@ class ProfileControllerTest {
@ParameterizedTest
@MethodSource
void testBatchIdentityCheckDeserializationBadRequest(final String json) {
void testBatchIdentityCheckDeserializationBadRequest(final String json, final int expectedStatus) {
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.entity(json, "application/json"))) {
assertThat(response).isNotNull();
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.getStatus()).isEqualTo(expectedStatus);
}
}
static Stream<Arguments> testBatchIdentityCheckDeserializationBadRequest() {
return Stream.of(
Arguments.of( // aci and uuid cannot both be null
"""
{
"elements": [
{ "aci": null, "uuid": null, "fingerprint": "%s" }
]
}
"""),
Arguments.of( // an empty string is also invalid
"""
{
"elements": [
{ "aci": "", "uuid": null, "fingerprint": "%s" }
]
}
"""
),
Arguments.of( // as is a blank string
"""
{
"elements": [
{ "aci": null, "uuid": " ", "fingerprint": "%s" }
]
}
"""),
Arguments.of( // aci and uuid cannot both be non-null
String.format("""
{
"elements": [
{ "aci": "%s", "uuid": "%s", "fingerprint": "%s" }
]
}
""", AuthHelper.VALID_UUID, AuthHelper.VALID_PNI,
Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))))
{
"elements": [
{ "uuid": null, "fingerprint": "%s" }
]
}
""", Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))),
422),
Arguments.of( // a blank string is invalid
String.format("""
{
"elements": [
{ "uuid": " ", "fingerprint": "%s" }
]
}
""", Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))),
400)
);
}

View File

@@ -9,20 +9,24 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Random;
import java.util.UUID;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class OutgoingMessageEntityTest {
@ParameterizedTest
@MethodSource
void roundTripThroughEnvelope(@Nullable final UUID sourceUuid, @Nullable final UUID updatedPni) {
@CartesianTest
@CartesianTest.MethodFactory("roundTripThroughEnvelope")
void roundTripThroughEnvelope(@Nullable final ServiceIdentifier sourceIdentifier,
final ServiceIdentifier destinationIdentifier,
@Nullable final UUID updatedPni) {
final byte[] messageContent = new byte[16];
new Random().nextBytes(messageContent);
@@ -35,9 +39,9 @@ class OutgoingMessageEntityTest {
UUID.randomUUID(),
MessageProtos.Envelope.Type.CIPHERTEXT_VALUE,
messageTimestamp,
UUID.randomUUID(),
sourceUuid != null ? (int) Device.MASTER_ID : 0,
UUID.randomUUID(),
sourceIdentifier,
sourceIdentifier != null ? (int) Device.MASTER_ID : 0,
destinationIdentifier,
updatedPni,
messageContent,
serverTimestamp,
@@ -48,11 +52,14 @@ class OutgoingMessageEntityTest {
assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope()));
}
private static Stream<Arguments> roundTripThroughEnvelope() {
return Stream.of(
Arguments.of(UUID.randomUUID(), UUID.randomUUID()),
Arguments.of(UUID.randomUUID(), null),
Arguments.of(null, UUID.randomUUID()));
@SuppressWarnings("unused")
static ArgumentSets roundTripThroughEnvelope() {
return ArgumentSets.argumentsForFirstParameter(new AciServiceIdentifier(UUID.randomUUID()),
new PniServiceIdentifier(UUID.randomUUID()),
null)
.argumentsForNextParameter(new AciServiceIdentifier(UUID.randomUUID()),
new PniServiceIdentifier(UUID.randomUUID()))
.argumentsForNextParameter(UUID.randomUUID(), null);
}
@Test
@@ -71,7 +78,7 @@ class OutgoingMessageEntityTest {
IncomingMessage message = new IncomingMessage(1, 4444L, 55, "AAAAAA");
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
UUID.randomUUID(),
new AciServiceIdentifier(UUID.randomUUID()),
account,
123L,
System.currentTimeMillis(),

View File

@@ -0,0 +1,77 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.nio.ByteBuffer;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class AciServiceIdentifierTest {
@Test
void identityType() {
assertEquals(IdentityType.ACI, new AciServiceIdentifier(UUID.randomUUID()).identityType());
}
@Test
void toServiceIdentifierString() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid.toString(), new AciServiceIdentifier(uuid).toServiceIdentifierString());
}
@Test
void toCompactByteArray() {
final UUID uuid = UUID.randomUUID();
assertArrayEquals(UUIDUtil.toBytes(uuid), new AciServiceIdentifier(uuid).toCompactByteArray());
}
@Test
void toFixedWidthByteArray() {
final UUID uuid = UUID.randomUUID();
final ByteBuffer expectedBytesBuffer = ByteBuffer.allocate(17);
expectedBytesBuffer.put((byte) 0x00);
expectedBytesBuffer.putLong(uuid.getMostSignificantBits());
expectedBytesBuffer.putLong(uuid.getLeastSignificantBits());
expectedBytesBuffer.flip();
assertArrayEquals(expectedBytesBuffer.array(), new AciServiceIdentifier(uuid).toFixedWidthByteArray());
}
@Test
void valueOf() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid, AciServiceIdentifier.valueOf(uuid.toString()).uuid());
assertEquals(uuid, AciServiceIdentifier.valueOf("ACI:" + uuid).uuid());
assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.valueOf("Not a valid UUID"));
assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.valueOf("PNI:" + uuid));
}
@Test
void fromBytes() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid, AciServiceIdentifier.fromBytes(UUIDUtil.toBytes(uuid)).uuid());
final byte[] prefixedBytes = new byte[17];
prefixedBytes[0] = 0x00;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, prefixedBytes, 1, 16);
assertEquals(uuid, AciServiceIdentifier.fromBytes(prefixedBytes).uuid());
prefixedBytes[0] = 0x01;
assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.fromBytes(prefixedBytes));
}
}

View File

@@ -0,0 +1,71 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.nio.ByteBuffer;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class PniServiceIdentifierTest {
@Test
void identityType() {
assertEquals(IdentityType.PNI, new PniServiceIdentifier(UUID.randomUUID()).identityType());
}
@Test
void toServiceIdentifierString() {
final UUID uuid = UUID.randomUUID();
assertEquals("PNI:" + uuid, new PniServiceIdentifier(uuid).toServiceIdentifierString());
}
@Test
void toByteArray() {
final UUID uuid = UUID.randomUUID();
final ByteBuffer expectedBytesBuffer = ByteBuffer.allocate(17);
expectedBytesBuffer.put((byte) 0x01);
expectedBytesBuffer.putLong(uuid.getMostSignificantBits());
expectedBytesBuffer.putLong(uuid.getLeastSignificantBits());
expectedBytesBuffer.flip();
assertArrayEquals(expectedBytesBuffer.array(), new PniServiceIdentifier(uuid).toCompactByteArray());
assertArrayEquals(expectedBytesBuffer.array(), new PniServiceIdentifier(uuid).toFixedWidthByteArray());
}
@Test
void valueOf() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid, PniServiceIdentifier.valueOf("PNI:" + uuid).uuid());
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf(uuid.toString()));
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf("Not a valid UUID"));
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf("ACI:" + uuid));
}
@Test
void fromBytes() {
final UUID uuid = UUID.randomUUID();
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.fromBytes(UUIDUtil.toBytes(uuid)));
final byte[] prefixedBytes = new byte[17];
prefixedBytes[0] = 0x00;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, prefixedBytes, 1, 16);
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.fromBytes(prefixedBytes));
prefixedBytes[0] = 0x01;
assertEquals(uuid, PniServiceIdentifier.fromBytes(prefixedBytes).uuid());
}
}

View File

@@ -0,0 +1,87 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class ServiceIdentifierTest {
@ParameterizedTest
@MethodSource
void valueOf(final String identifierString, final IdentityType expectedIdentityType, final UUID expectedUuid) {
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.valueOf(identifierString);
assertEquals(expectedIdentityType, serviceIdentifier.identityType());
assertEquals(expectedUuid, serviceIdentifier.uuid());
}
private static Stream<Arguments> valueOf() {
final UUID uuid = UUID.randomUUID();
return Stream.of(
Arguments.of(uuid.toString(), IdentityType.ACI, uuid),
Arguments.of("ACI:" + uuid, IdentityType.ACI, uuid),
Arguments.of("PNI:" + uuid, IdentityType.PNI, uuid));
}
@ParameterizedTest
@ValueSource(strings = {"Not a valid UUID", "BAD:a9edc243-3e93-45d4-95c6-e3a84cd4a254"})
void valueOfIllegalArgument(final String identifierString) {
assertThrows(IllegalArgumentException.class, () -> ServiceIdentifier.valueOf(identifierString));
}
@ParameterizedTest
@MethodSource
void fromBytes(final byte[] bytes, final IdentityType expectedIdentityType, final UUID expectedUuid) {
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromBytes(bytes);
assertEquals(expectedIdentityType, serviceIdentifier.identityType());
assertEquals(expectedUuid, serviceIdentifier.uuid());
}
private static Stream<Arguments> fromBytes() {
final UUID uuid = UUID.randomUUID();
final byte[] aciPrefixedBytes = new byte[17];
aciPrefixedBytes[0] = 0x00;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, aciPrefixedBytes, 1, 16);
final byte[] pniPrefixedBytes = new byte[17];
pniPrefixedBytes[0] = 0x01;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, pniPrefixedBytes, 1, 16);
return Stream.of(
Arguments.of(UUIDUtil.toBytes(uuid), IdentityType.ACI, uuid),
Arguments.of(aciPrefixedBytes, IdentityType.ACI, uuid),
Arguments.of(pniPrefixedBytes, IdentityType.PNI, uuid));
}
@ParameterizedTest
@MethodSource
void fromBytesIllegalArgument(final byte[] bytes) {
assertThrows(IllegalArgumentException.class, () -> ServiceIdentifier.fromBytes(bytes));
}
private static Stream<Arguments> fromBytesIllegalArgument() {
final byte[] invalidPrefixBytes = new byte[17];
invalidPrefixBytes[0] = (byte) 0xff;
return Stream.of(
Arguments.of(new byte[0]),
Arguments.of(new byte[15]),
Arguments.of(new byte[18]),
Arguments.of(invalidPrefixBytes));
}
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.metrics;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -21,6 +22,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
class MessageMetricsTest {
@@ -35,6 +39,9 @@ class MessageMetricsTest {
void setup() {
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
when(account.isIdentifiedBy(any())).thenReturn(false);
when(account.isIdentifiedBy(new AciServiceIdentifier(aci))).thenReturn(true);
when(account.isIdentifiedBy(new PniServiceIdentifier(pni))).thenReturn(true);
Metrics.globalRegistry.clear();
simpleMeterRegistry = new SimpleMeterRegistry();
Metrics.globalRegistry.add(simpleMeterRegistry);
@@ -49,46 +56,46 @@ class MessageMetricsTest {
@Test
void measureAccountOutgoingMessageUuidMismatches() {
final OutgoingMessageEntity outgoingMessageToAci = createOutgoingMessageEntity(aci);
final OutgoingMessageEntity outgoingMessageToAci = createOutgoingMessageEntity(new AciServiceIdentifier(aci));
MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToAci);
Optional<Counter> counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty());
final OutgoingMessageEntity outgoingMessageToPni = createOutgoingMessageEntity(pni);
final OutgoingMessageEntity outgoingMessageToPni = createOutgoingMessageEntity(new PniServiceIdentifier(pni));
MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToPni);
counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty());
final OutgoingMessageEntity outgoingMessageToOtherUuid = createOutgoingMessageEntity(otherUuid);
final OutgoingMessageEntity outgoingMessageToOtherUuid = createOutgoingMessageEntity(new AciServiceIdentifier(otherUuid));
MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToOtherUuid);
counter = findCounter(simpleMeterRegistry);
assertEquals(1.0, counter.map(Counter::count).orElse(0.0));
}
private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) {
return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true, false, null);
private OutgoingMessageEntity createOutgoingMessageEntity(final ServiceIdentifier destinationIdentifier) {
return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationIdentifier, null, new byte[]{}, 1, true, false, null);
}
@Test
void measureAccountEnvelopeUuidMismatches() {
final MessageProtos.Envelope envelopeToAci = createEnvelope(aci);
final MessageProtos.Envelope envelopeToAci = createEnvelope(new AciServiceIdentifier(aci));
MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToAci);
Optional<Counter> counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty());
final MessageProtos.Envelope envelopeToPni = createEnvelope(pni);
final MessageProtos.Envelope envelopeToPni = createEnvelope(new PniServiceIdentifier(pni));
MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToPni);
counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty());
final MessageProtos.Envelope envelopeToOtherUuid = createEnvelope(otherUuid);
final MessageProtos.Envelope envelopeToOtherUuid = createEnvelope(new AciServiceIdentifier(otherUuid));
MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToOtherUuid);
counter = findCounter(simpleMeterRegistry);
@@ -101,11 +108,11 @@ class MessageMetricsTest {
assertEquals(1.0, counter.map(Counter::count).orElse(0.0));
}
private MessageProtos.Envelope createEnvelope(UUID destinationUuid) {
private MessageProtos.Envelope createEnvelope(ServiceIdentifier destinationIdentifier) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
if (destinationUuid != null) {
builder.setDestinationUuid(destinationUuid.toString());
if (destinationIdentifier != null) {
builder.setDestinationUuid(destinationIdentifier.toServiceIdentifierString());
}
return builder.build();

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -61,6 +62,8 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -208,6 +211,21 @@ class AccountsManagerTest {
mock(Clock.class));
}
@Test
void testGetByServiceIdentifier() {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
when(commands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString());
when(commands.get(eq("Account3::" + aci))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent());
assertTrue(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni)).isPresent());
assertFalse(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(pni)).isPresent());
assertFalse(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(aci)).isPresent());
}
@Test
void testGetAccountByNumberInCache() {
UUID uuid = UUID.randomUUID();

View File

@@ -55,6 +55,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -225,7 +226,7 @@ class WebSocketConnectionTest {
verify(messagesManager, times(1)).delete(eq(accountUuid), eq(deviceId),
eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp()));
verify(receiptSender, times(1)).sendReceipt(eq(accountUuid), eq(deviceId), eq(senderOneUuid),
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(accountUuid)), eq(deviceId), eq(new AciServiceIdentifier(senderOneUuid)),
eq(2222L));
connection.stop();
@@ -369,7 +370,7 @@ class WebSocketConnectionTest {
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account.getUuid()), eq(deviceId), eq(senderTwoUuid),
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getTimestamp()));
connection.stop();