diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 9db413a67..b034a2c9c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -700,7 +700,7 @@ public class WhisperServerService extends Application syncMessageSenderDeviceId, @Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException { - if (!destination.isIdentifiedBy(destinationIdentifier)) { - throw new IllegalArgumentException("Destination account not identified by destination service identifier"); - } - final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent); - if (messagesByDeviceId.isEmpty()) { - Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME, - Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment(); - } - - final byte excludedDeviceId; - if (syncMessageSenderDeviceId.isPresent()) { - if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) || - !destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) { - - throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender"); - } - excludedDeviceId = syncMessageSenderDeviceId.get(); - } else { - if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) && - destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) { - - throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender"); - } - excludedDeviceId = NO_EXCLUDED_DEVICE_ID; - } - - final Optional maybeMismatchedDevices = getMismatchedDevices(destination, + validateIndividualMessageBundle(destination, destinationIdentifier, + messagesByDeviceId, registrationIdsByDeviceId, - excludedDeviceId); - - if (maybeMismatchedDevices.isPresent()) { - throw new MismatchedDevicesException(maybeMismatchedDevices.get()); - } - - validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), userAgent); + syncMessageSenderDeviceId, + platformTag); messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId) .forEach((deviceId, destinationPresent) -> { @@ -206,7 +170,9 @@ public class MessageSender { final boolean isUrgent, @Nullable final String userAgent) throws MultiRecipientMismatchedDevicesException, MessageTooLargeException { - validateMultiRecipientMessageContentLength(multiRecipientMessage, isStory, userAgent); + final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent); + + validateMultiRecipientMessageContentLength(multiRecipientMessage, isStory, platformTag); final Map mismatchedDevicesByServiceIdentifier = new HashMap<>(); @@ -252,24 +218,105 @@ public class MessageSender { SEALED_SENDER_TAG_NAME, "true", SYNC_MESSAGE_TAG_NAME, "false", MULTI_RECIPIENT_TAG_NAME, "true") - .and(UserAgentTagUtil.getPlatformTag(userAgent)); + .and(platformTag); Metrics.counter(SEND_COUNTER_NAME, tags).increment(); }))) .thenRun(Util.NOOP); } + /** + * Validates that a bundle of messages destined for an individual account is well-formed and may be delivered. Note + * that all checks performed by this method are also performed by + * {@link #sendMessages(Account, ServiceIdentifier, Map, Map, Optional, String)}; callers should only invoke this + * method if they need to verify that a bundle of individual messages is valid before trying to send the + * messages (i.e. if the caller must take some other action in conjunction with sending the messages and cannot + * reverse that action if message sending fails). + * + * @param destination the account to which to send messages + * @param destinationIdentifier the service identifier to which the messages are addressed + * @param messagesByDeviceId a map of device IDs to message payloads + * @param registrationIdsByDeviceId a map of device IDs to device registration IDs + * @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the + * caller's own account), contains the ID of the device that sent the message + * @param userAgent the User-Agent string for the sender; may be {@code null} if not known + * + * @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required + * devices, contained messages for devices not linked to the destination account, or devices with outdated + * registration IDs + * @throws MessageTooLargeException if the given message payload is too large + */ + public static void validateIndividualMessageBundle(final Account destination, + final ServiceIdentifier destinationIdentifier, + final Map messagesByDeviceId, + final Map registrationIdsByDeviceId, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional syncMessageSenderDeviceId, + @Nullable final String userAgent) throws MessageTooLargeException, MismatchedDevicesException { + + validateIndividualMessageBundle(destination, + destinationIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId, + syncMessageSenderDeviceId, + UserAgentTagUtil.getPlatformTag(userAgent)); + } + + private static void validateIndividualMessageBundle(final Account destination, + final ServiceIdentifier destinationIdentifier, + final Map messagesByDeviceId, + final Map registrationIdsByDeviceId, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional syncMessageSenderDeviceId, + final Tag platformTag) throws MismatchedDevicesException, MessageTooLargeException { + + if (!destination.isIdentifiedBy(destinationIdentifier)) { + throw new IllegalArgumentException("Destination account not identified by destination service identifier"); + } + + if (messagesByDeviceId.isEmpty()) { + Metrics.counter(EMPTY_MESSAGE_LIST_COUNTER_NAME, + Tags.of(SYNC_MESSAGE_TAG_NAME, String.valueOf(syncMessageSenderDeviceId.isPresent())).and(platformTag)).increment(); + } + + final byte excludedDeviceId; + if (syncMessageSenderDeviceId.isPresent()) { + if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) || + !destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) { + + throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender"); + } + excludedDeviceId = syncMessageSenderDeviceId.get(); + } else { + if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) && + destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) { + + throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender"); + } + excludedDeviceId = NO_EXCLUDED_DEVICE_ID; + } + + final Optional maybeMismatchedDevices = getMismatchedDevices(destination, + destinationIdentifier, + registrationIdsByDeviceId, + excludedDeviceId); + + if (maybeMismatchedDevices.isPresent()) { + throw new MismatchedDevicesException(maybeMismatchedDevices.get()); + } + + validateIndividualMessageContentLength(messagesByDeviceId.values(), syncMessageSenderDeviceId.isPresent(), platformTag); + } + @VisibleForTesting static void validateContentLength(final int contentLength, final boolean isMultiRecipientMessage, final boolean isSyncMessage, final boolean isStory, - final String userAgent) throws MessageTooLargeException { + final Tag platformTag) throws MessageTooLargeException { final boolean oversize = contentLength > MAX_MESSAGE_SIZE; DistributionSummary.builder(CONTENT_SIZE_DISTRIBUTION_NAME) - .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + .tags(Tags.of(platformTag, Tag.of("oversize", String.valueOf(oversize)), Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)), Tag.of("syncMessage", String.valueOf(isSyncMessage)), @@ -279,7 +326,7 @@ public class MessageSender { .record(contentLength); if (oversize) { - Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + Metrics.counter(REJECT_OVERSIZE_MESSAGE_COUNTER_NAME, Tags.of(platformTag, Tag.of("multiRecipientMessage", String.valueOf(isMultiRecipientMessage)), Tag.of("syncMessage", String.valueOf(isSyncMessage)), Tag.of("story", String.valueOf(isStory)))) @@ -330,27 +377,27 @@ public class MessageSender { private static void validateIndividualMessageContentLength(final Iterable messages, final boolean isSyncMessage, - @Nullable final String userAgent) throws MessageTooLargeException { + final Tag platformTag) throws MessageTooLargeException { for (final Envelope message : messages) { MessageSender.validateContentLength(message.getContent().size(), false, isSyncMessage, message.getStory(), - userAgent); + platformTag); } } private static void validateMultiRecipientMessageContentLength(final SealedSenderMultiRecipientMessage multiRecipientMessage, final boolean isStory, - @Nullable final String userAgent) throws MessageTooLargeException { + final Tag platformTag) throws MessageTooLargeException { for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) { MessageSender.validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), true, false, isStory, - userAgent); + platformTag); } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 9471bb482..4cf04b624 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -4,13 +4,11 @@ */ package org.whispersystems.textsecuregcm.storage; -import com.google.protobuf.ByteString; import java.time.Clock; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; -import org.apache.commons.lang3.ObjectUtils; import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +18,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; -import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageTooLargeException; @@ -51,48 +49,57 @@ public class ChangeNumberManager { final String senderUserAgent) throws InterruptedException, MismatchedDevicesException, MessageTooLargeException { + final long serverTimestamp = clock.millis(); + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI)); + + // Note that these for-validation envelopes do NOT have the "updated PNI" field set, and we'll need to populate that + // after actually changing the account's number. + final Map messagesByDeviceId = deviceMessages.stream() + .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> message.toEnvelope(serviceIdentifier, + serviceIdentifier, + Device.PRIMARY_ID, + serverTimestamp, + false, + false, + true, + null, + clock))); + + final Map registrationIdsByDeviceId = deviceMessages.stream() + .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); + + // Make sure we can plausibly deliver the messages to other devices on the account before making any changes to the + // account itself + if (!messagesByDeviceId.isEmpty()) { + MessageSender.validateIndividualMessageBundle(account, + serviceIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId, + Optional.of(Device.PRIMARY_ID), + senderUserAgent); + } + final Account updatedAccount = accountsManager.changeNumber( account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds); - sendDeviceMessages(updatedAccount, deviceMessages, senderUserAgent); - - return updatedAccount; - } - - private void sendDeviceMessages(final Account account, - final List deviceMessages, - final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException { - try { - final long serverTimestamp = clock.millis(); - final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid()); + // Now that we've actually updated the account, populate the "updated PNI" field on all envelopes + final String updatedPniString = updatedAccount.getIdentifier(IdentityType.PNI).toString(); - final Map messagesByDeviceId = deviceMessages.stream() - .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder() - .setType(Envelope.Type.forNumber(message.type())) - .setClientTimestamp(serverTimestamp) - .setServerTimestamp(serverTimestamp) - .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) - .setContent(ByteString.copyFrom(message.content())) - .setSourceServiceId(serviceIdentifier.toServiceIdentifierString()) - .setSourceDevice(Device.PRIMARY_ID) - .setUpdatedPni(account.getPhoneNumberIdentifier().toString()) - .setUrgent(true) - .setEphemeral(false) - .build())); + messagesByDeviceId.replaceAll((deviceId, envelope) -> + envelope.toBuilder().setUpdatedPni(updatedPniString).build()); - final Map registrationIdsByDeviceId = deviceMessages.stream() - .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); - - messageSender.sendMessages(account, + messageSender.sendMessages(updatedAccount, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, Optional.of(Device.PRIMARY_ID), senderUserAgent); } catch (final RuntimeException e) { - logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); + logger.warn("Changed number but could not send all device messages for {}", account.getIdentifier(IdentityType.ACI), e); throw e; } + + return updatedAccount; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java index 86a62432b..4e1cac624 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities; import static org.junit.jupiter.api.Assertions.assertEquals; +import java.time.Clock; import java.util.UUID; import javax.annotation.Nullable; import org.junit.jupiter.api.Test; @@ -75,7 +76,8 @@ class OutgoingMessageEntityTest { false, false, true, - reportSpamToken); + reportSpamToken, + Clock.systemUTC()); MessageProtos.Envelope envelope = baseEnvelope.toBuilder().setServerGuid(UUID.randomUUID().toString()).build(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index 8ff5e892b..fefa12981 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -18,6 +18,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import com.google.protobuf.ByteString; +import io.micrometer.core.instrument.Tag; import java.util.Collections; import java.util.List; import java.util.Map; @@ -25,8 +27,10 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -38,31 +42,30 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevices; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper; import org.whispersystems.textsecuregcm.tests.util.TestRecipient; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; class MessageSenderTest { private MessagesManager messagesManager; private PushNotificationManager pushNotificationManager; private MessageSender messageSender; - private ExperimentEnrollmentManager experimentEnrollmentManager; @BeforeEach void setUp() { messagesManager = mock(MessagesManager.class); pushNotificationManager = mock(PushNotificationManager.class); - experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); - messageSender = new MessageSender(messagesManager, pushNotificationManager, experimentEnrollmentManager); + messageSender = new MessageSender(messagesManager, pushNotificationManager); } @@ -258,13 +261,183 @@ class MessageSenderTest { mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier()); } + @ParameterizedTest + @MethodSource + void validateIndividualMessageBundle(final Account destination, + final ServiceIdentifier destinationIdentifier, + final Map messagesByDeviceId, + final Map registrationIdsByDeviceId, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional syncMessageSenderDeviceId, + @Nullable final Class expectedExceptionClass) { + + final Executable validateIndividualMessageBundle = () -> MessageSender.validateIndividualMessageBundle(destination, + destinationIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId, + syncMessageSenderDeviceId, + "Signal/Test"); + + if (expectedExceptionClass != null) { + assertThrows(expectedExceptionClass, validateIndividualMessageBundle); + } else { + assertDoesNotThrow(validateIndividualMessageBundle); + } + } + + private static List validateIndividualMessageBundle() { + final ServiceIdentifier destinationIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + final byte primaryDeviceId = Device.PRIMARY_ID; + final byte linkedDeviceId = primaryDeviceId + 1; + + final int primaryDeviceRegistrationId = 17; + final int linkedDeviceRegistrationId = primaryDeviceRegistrationId + 1; + + final Device primaryDevice = mock(Device.class); + when(primaryDevice.getId()).thenReturn(primaryDeviceId); + when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(primaryDeviceRegistrationId); + + final Device linkedDevice = mock(Device.class); + when(linkedDevice.getId()).thenReturn(linkedDeviceId); + when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedDeviceRegistrationId); + + final Account destination = mock(Account.class); + when(destination.isIdentifiedBy(any())).thenReturn(false); + when(destination.isIdentifiedBy(destinationIdentifier)).thenReturn(true); + when(destination.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice)); + when(destination.getDevice(anyByte())).thenReturn(Optional.empty()); + when(destination.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice)); + when(destination.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice)); + + return List.of( + Arguments.argumentSet("Valid", + destination, + destinationIdentifier, + Map.of( + primaryDeviceId, generateEnvelope(null, 16), + linkedDeviceId, generateEnvelope(null, 16)), + Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId, linkedDeviceRegistrationId), + Optional.empty(), + null), + + Arguments.argumentSet("Mismatched service ID", + destination, + new AciServiceIdentifier(UUID.randomUUID()), + Map.of( + primaryDeviceId, generateEnvelope(null, 16), + linkedDeviceId, generateEnvelope(null, 16)), + Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId, linkedDeviceRegistrationId), + Optional.empty(), + IllegalArgumentException.class), + + Arguments.argumentSet("Sync message without source on all messages", + destination, + destinationIdentifier, + Map.of(linkedDeviceId, generateEnvelope(null, 16)), + Map.of(linkedDeviceId, linkedDeviceRegistrationId), + Optional.of(primaryDevice), + IllegalArgumentException.class), + + Arguments.argumentSet("Sync message to other account", + destination, + destinationIdentifier, + Map.of(linkedDeviceId, generateEnvelope(new AciServiceIdentifier(UUID.randomUUID()), 16)), + Map.of(linkedDeviceId, linkedDeviceRegistrationId), + Optional.of(primaryDevice), + IllegalArgumentException.class), + + Arguments.argumentSet("Sync message to other account", + destination, + destinationIdentifier, + Map.of(linkedDeviceId, generateEnvelope(new AciServiceIdentifier(UUID.randomUUID()), 16)), + Map.of(linkedDeviceId, linkedDeviceRegistrationId), + Optional.of(primaryDevice), + IllegalArgumentException.class), + + Arguments.argumentSet("Non-sync message addressed to sender", + destination, + destinationIdentifier, + Map.of( + primaryDeviceId, generateEnvelope(destinationIdentifier, 16), + linkedDeviceId, generateEnvelope(destinationIdentifier, 16)), + Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId, linkedDeviceRegistrationId), + Optional.empty(), + IllegalArgumentException.class), + + Arguments.argumentSet("Non-sync message addressed to sender", + destination, + destinationIdentifier, + Map.of( + primaryDeviceId, generateEnvelope(destinationIdentifier, 16), + linkedDeviceId, generateEnvelope(destinationIdentifier, 16)), + Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId, linkedDeviceRegistrationId), + Optional.empty(), + IllegalArgumentException.class), + + Arguments.argumentSet("Mismatched devices in message set", + destination, + destinationIdentifier, + Map.of( + primaryDeviceId, generateEnvelope(null, 16), + linkedDeviceId + 1, generateEnvelope(null, 16)), + Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId + 1, linkedDeviceRegistrationId), + Optional.empty(), + MismatchedDevicesException.class), + + Arguments.argumentSet("Mismatched registration IDs", + destination, + destinationIdentifier, + Map.of( + primaryDeviceId, generateEnvelope(null, 16), + linkedDeviceId, generateEnvelope(null, 16)), + Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId, linkedDeviceRegistrationId + 1), + Optional.empty(), + MismatchedDevicesException.class), + + Arguments.argumentSet("Oversized message", + destination, + destinationIdentifier, + Map.of( + primaryDeviceId, generateEnvelope(null, MessageSender.MAX_MESSAGE_SIZE + 1), + linkedDeviceId, generateEnvelope(null, MessageSender.MAX_MESSAGE_SIZE + 1)), + Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId, linkedDeviceRegistrationId), + Optional.empty(), + MessageTooLargeException.class) + ); + } + + private static MessageProtos.Envelope generateEnvelope(@Nullable ServiceIdentifier sourceIdentifier, final int contentLength) { + final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() + .setContent(ByteString.copyFrom(TestRandomUtil.nextBytes(contentLength))); + + if (sourceIdentifier != null) { + envelopeBuilder.setSourceServiceId(sourceIdentifier.toServiceIdentifierString()); + } + + return envelopeBuilder.build(); + } + @Test void validateContentLength() { assertThrows(MessageTooLargeException.class, () -> - MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE + 1, false, false, false, null)); + MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE + 1, false, false, false, Tag.of(UserAgentTagUtil.PLATFORM_TAG, "test"))); assertDoesNotThrow(() -> - MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null)); + MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, Tag.of(UserAgentTagUtil.PLATFORM_TAG, "test"))); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index ac67362eb..5233f1f0b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -4,7 +4,6 @@ */ package org.whispersystems.textsecuregcm.storage; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.argThat; @@ -14,6 +13,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.protobuf.ByteString; import java.time.Instant; import java.util.Collections; @@ -21,11 +21,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; @@ -36,6 +34,7 @@ import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.TestClock; @@ -61,154 +60,136 @@ public class ChangeNumberManagerTest { final Account account = invocation.getArgument(0, Account.class); final String number = invocation.getArgument(1, String.class); - final UUID uuid = account.getUuid(); + final UUID uuid = account.getIdentifier(IdentityType.ACI); final List devices = account.getDevices(); final UUID updatedPni = UUID.randomUUID(); updatedPhoneNumberIdentifiersByAccount.put(account, updatedPni); final Account updatedAccount = mock(Account.class); - when(updatedAccount.getUuid()).thenReturn(uuid); + when(updatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(uuid); + when(updatedAccount.getIdentifier(IdentityType.PNI)).thenReturn(updatedPni); + when(updatedAccount.isIdentifiedBy(any())).thenReturn(false); + when(updatedAccount.isIdentifiedBy(new AciServiceIdentifier(uuid))).thenReturn(true); + when(updatedAccount.isIdentifiedBy(new PniServiceIdentifier(updatedPni))).thenReturn(true); when(updatedAccount.getNumber()).thenReturn(number); - when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(updatedPni); when(updatedAccount.getDevices()).thenReturn(devices); - for (byte i = 1; i <= 3; i++) { - final Optional d = account.getDevice(i); - when(updatedAccount.getDevice(i)).thenReturn(d); - } + when(updatedAccount.getDevice(anyByte())).thenReturn(Optional.empty()); + + account.getDevices().forEach(device -> + when(updatedAccount.getDevice(device.getId())).thenReturn(Optional.of(device))); return updatedAccount; }); } @Test - void changeNumberSetPrimaryDevicePrekey() throws Exception { - Account account = mock(Account.class); - when(account.getNumber()).thenReturn("+18005551234"); + void changeNumberSingleDevice() throws Exception { + final String targetNumber = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - final Map prekeys = Map.of(Device.PRIMARY_ID, - KeysHelper.signedECPreKey(1, pniIdentityKeyPair)); - changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null); - verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); + final Map ecSignedPreKeys = + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)); + + final Map kemLastResortPreKeys = + Map.of(Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(2, pniIdentityKeyPair)); + + final UUID accountIdentifier = UUID.randomUUID(); + + final Account account = mock(Account.class); + when(account.isIdentifiedBy(any())).thenReturn(false); + when(account.isIdentifiedBy(new AciServiceIdentifier(accountIdentifier))).thenReturn(true); + + changeNumberManager.changeNumber(account, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyList(), Collections.emptyMap(), null); + verify(accountsManager).changeNumber(account, targetNumber, pniIdentityKey, ecSignedPreKeys, kemLastResortPreKeys, Collections.emptyMap()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any()); } @Test - void changeNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception { - final String originalE164 = "+18005551234"; - final String changedE164 = "+18025551234"; - final UUID aci = UUID.randomUUID(); - final UUID pni = UUID.randomUUID(); + void changeNumberLinkedDevices() throws Exception { + final String targetNumber = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164); - final Account account = mock(Account.class); - when(account.getNumber()).thenReturn(originalE164); - when(account.getUuid()).thenReturn(aci); - when(account.getPhoneNumberIdentifier()).thenReturn(pni); + final UUID aci = UUID.randomUUID(); + + final byte primaryDeviceId = Device.PRIMARY_ID; + final byte linkedDeviceId = primaryDeviceId + 1; + + final int primaryDeviceRegistrationId = 17; + final int linkedDeviceRegistrationId = primaryDeviceRegistrationId + 1; final Device primaryDevice = mock(Device.class); - when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(7); + when(primaryDevice.getId()).thenReturn(primaryDeviceId); + when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(primaryDeviceRegistrationId); final Device linkedDevice = mock(Device.class); - final byte linkedDeviceId = Device.PRIMARY_ID + 1; - final int linkedDeviceRegistrationId = 17; when(linkedDevice.getId()).thenReturn(linkedDeviceId); when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedDeviceRegistrationId); + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(aci); + when(account.isIdentifiedBy(any())).thenReturn(false); + when(account.isIdentifiedBy(new AciServiceIdentifier(aci))).thenReturn(true); when(account.getDevice(anyByte())).thenReturn(Optional.empty()); - when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice)); + when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice)); when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice)); when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice)); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map prekeys = Map.of(Device.PRIMARY_ID, - KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + final Map ecSignedPreKeys = Map.of( + primaryDeviceId, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19); - final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.type()).thenReturn(1); - when(msg.destinationDeviceId()).thenReturn(linkedDeviceId); - when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId); - when(msg.content()).thenReturn(new byte[]{1}); + final Map kemLastResortPreKeys = Map.of( + primaryDeviceId, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), + linkedDeviceId, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null); + final Map registrationIds = Map.of( + primaryDeviceId, primaryDeviceRegistrationId, + linkedDeviceId, linkedDeviceRegistrationId); - verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); + final IncomingMessage incomingMessage = + new IncomingMessage(1, linkedDeviceId, linkedDeviceRegistrationId, new byte[] { 1 }); + + changeNumberManager.changeNumber(account, + targetNumber, + pniIdentityKey, + ecSignedPreKeys, + kemLastResortPreKeys, + List.of(incomingMessage), + registrationIds, + null); + + verify(accountsManager).changeNumber(account, + targetNumber, + pniIdentityKey, + ecSignedPreKeys, + kemLastResortPreKeys, + registrationIds); final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder() - .setType(MessageProtos.Envelope.Type.forNumber(msg.type())) + .setType(MessageProtos.Envelope.Type.forNumber(incomingMessage.type())) .setClientTimestamp(CLOCK.millis()) .setServerTimestamp(CLOCK.millis()) .setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString()) - .setContent(ByteString.copyFrom(msg.content())) + .setContent(ByteString.copyFrom(incomingMessage.content())) .setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString()) - .setSourceDevice(Device.PRIMARY_ID) + .setSourceDevice(primaryDeviceId) .setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString()) .setUrgent(true) .setEphemeral(false) + .setStory(false) .build(); - verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)), + verify(messageSender).sendMessages(argThat(a -> a.getIdentifier(IdentityType.ACI).equals(aci)), eq(new AciServiceIdentifier(aci)), eq(Map.of(linkedDeviceId, expectedEnvelope)), eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)), - eq(Optional.of(Device.PRIMARY_ID)), + eq(Optional.of(primaryDeviceId)), any()); } - - @Test - void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { - final String originalE164 = "+18005551234"; - final String changedE164 = "+18025551234"; - final UUID aci = UUID.randomUUID(); - final UUID pni = UUID.randomUUID(); - - final Account account = mock(Account.class); - when(account.getNumber()).thenReturn(originalE164); - when(account.getUuid()).thenReturn(aci); - when(account.getPhoneNumberIdentifier()).thenReturn(pni); - - final Device d2 = mock(Device.class); - final byte deviceId2 = 2; - when(d2.getId()).thenReturn(deviceId2); - - when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); - when(account.getDevices()).thenReturn(List.of(d2)); - - final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map prekeys = Map.of(Device.PRIMARY_ID, - KeysHelper.signedECPreKey(1, pniIdentityKeyPair), - deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), - (byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); - - final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.destinationDeviceId()).thenReturn(deviceId2); - when(msg.content()).thenReturn(new byte[]{1}); - - changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds, null); - - verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds); - - @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = - ArgumentCaptor.forClass(Map.class); - - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); - - assertEquals(1, envelopeCaptor.getValue().size()); - assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); - - final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); - - assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); - assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); - assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); - assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); - } }